Skip to content

Commit

Permalink
Use Optional<ExchangeHandle> variable instead of ExchangeContext in P…
Browse files Browse the repository at this point in the history
…airingSession.

This ensures that the ExchangeContext for the session is automatically
reference counted without explicit manual Retain() and Release() calls.
As a result, the ExchangeContext is held until PairingSession::Clear()
gets called that, in turn, calls ClearValue() on the Optional to
internally call Release() on the underlying target.

Fixes Issue project-chip#32498.
  • Loading branch information
pidarped committed Mar 8, 2024
1 parent b742587 commit 912adec
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 59 deletions.
48 changes: 25 additions & 23 deletions src/protocols/secure_channel/CASESession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,8 @@ CHIP_ERROR CASESession::EstablishSession(SessionManager & sessionManager, Fabric

// We are setting the exchange context specifically before checking for error.
// This is to make sure the exchange will get closed if Init() returned an error.
mExchangeCtxt = exchangeCtxt;
ExchangeHandle ecHandle(*exchangeCtxt);
mExchangeCtxt.SetValue(ecHandle);

// From here onwards, let's go to exit on error, as some state might have already
// been initialized
Expand All @@ -527,7 +528,7 @@ CHIP_ERROR CASESession::EstablishSession(SessionManager & sessionManager, Fabric
mSessionResumptionStorage = sessionResumptionStorage;
mLocalMRPConfig = mrpLocalConfig.ValueOr(GetDefaultMRPConfig());

mExchangeCtxt->UseSuggestedResponseTimeout(kExpectedSigma1ProcessingTime);
mExchangeCtxt.Value()->UseSuggestedResponseTimeout(kExpectedSigma1ProcessingTime);
mPeerNodeId = peerScopedNodeId.GetNodeId();
mLocalNodeId = fabricInfo->GetNodeId();

Expand All @@ -549,7 +550,7 @@ void CASESession::OnResponseTimeout(ExchangeContext * ec)
{
MATTER_TRACE_SCOPE("OnResponseTimeout", "CASESession");
VerifyOrReturn(ec != nullptr, ChipLogError(SecureChannel, "CASESession::OnResponseTimeout was called by null exchange"));
VerifyOrReturn(mExchangeCtxt == ec, ChipLogError(SecureChannel, "CASESession::OnResponseTimeout exchange doesn't match"));
VerifyOrReturn(&mExchangeCtxt.Value().Get() == ec, ChipLogError(SecureChannel, "CASESession::OnResponseTimeout exchange doesn't match"));
ChipLogError(SecureChannel, "CASESession timed out while waiting for a response from the peer. Current state was %u",
to_underlying(mState));
MATTER_TRACE_COUNTER("CASETimeout");
Expand Down Expand Up @@ -735,8 +736,8 @@ CHIP_ERROR CASESession::SendSigma1()
ReturnErrorOnFailure(mCommissioningHash.AddData(ByteSpan{ msg_R1->Start(), msg_R1->DataLength() }));

// Call delegate to send the msg to peer
ReturnErrorOnFailure(mExchangeCtxt->SendMessage(Protocols::SecureChannel::MsgType::CASE_Sigma1, std::move(msg_R1),
SendFlags(SendMessageFlags::kExpectResponse)));
ReturnErrorOnFailure(mExchangeCtxt.Value()->SendMessage(Protocols::SecureChannel::MsgType::CASE_Sigma1, std::move(msg_R1),
SendFlags(SendMessageFlags::kExpectResponse)));

mState = resuming ? State::kSentSigma1Resume : State::kSentSigma1;

Expand Down Expand Up @@ -959,8 +960,8 @@ CHIP_ERROR CASESession::SendSigma2Resume()
ReturnErrorOnFailure(tlvWriter.Finalize(&msg_R2_resume));

// Call delegate to send the msg to peer
ReturnErrorOnFailure(mExchangeCtxt->SendMessage(Protocols::SecureChannel::MsgType::CASE_Sigma2Resume, std::move(msg_R2_resume),
SendFlags(SendMessageFlags::kExpectResponse)));
ReturnErrorOnFailure(mExchangeCtxt.Value()->SendMessage(Protocols::SecureChannel::MsgType::CASE_Sigma2Resume, std::move(msg_R2_resume),
SendFlags(SendMessageFlags::kExpectResponse)));

mState = State::kSentSigma2Resume;

Expand Down Expand Up @@ -1096,8 +1097,8 @@ CHIP_ERROR CASESession::SendSigma2()
ReturnErrorOnFailure(mCommissioningHash.AddData(ByteSpan{ msg_R2->Start(), msg_R2->DataLength() }));

// Call delegate to send the msg to peer
ReturnErrorOnFailure(mExchangeCtxt->SendMessage(Protocols::SecureChannel::MsgType::CASE_Sigma2, std::move(msg_R2),
SendFlags(SendMessageFlags::kExpectResponse)));
ReturnErrorOnFailure(mExchangeCtxt.Value()->SendMessage(Protocols::SecureChannel::MsgType::CASE_Sigma2, std::move(msg_R2),
SendFlags(SendMessageFlags::kExpectResponse)));

mState = State::kSentSigma2;

Expand Down Expand Up @@ -1148,7 +1149,7 @@ CHIP_ERROR CASESession::HandleSigma2Resume(System::PacketBufferHandle && msg)
if (tlvReader.Next() != CHIP_END_OF_TLV)
{
SuccessOrExit(err = DecodeMRPParametersIfPresent(TLV::ContextTag(4), tlvReader));
mExchangeCtxt->GetSessionHandle()->AsUnauthenticatedSession()->SetRemoteSessionParameters(GetRemoteSessionParameters());
mExchangeCtxt.Value()->GetSessionHandle()->AsUnauthenticatedSession()->SetRemoteSessionParameters(GetRemoteSessionParameters());
}

ChipLogDetail(SecureChannel, "Peer assigned session session ID %d", responderSessionId);
Expand Down Expand Up @@ -1341,7 +1342,7 @@ CHIP_ERROR CASESession::HandleSigma2(System::PacketBufferHandle && msg)
if (tlvReader.Next() != CHIP_END_OF_TLV)
{
SuccessOrExit(err = DecodeMRPParametersIfPresent(TLV::ContextTag(kTag_Sigma2_ResponderMRPParams), tlvReader));
mExchangeCtxt->GetSessionHandle()->AsUnauthenticatedSession()->SetRemoteSessionParameters(GetRemoteSessionParameters());
mExchangeCtxt.Value()->GetSessionHandle()->AsUnauthenticatedSession()->SetRemoteSessionParameters(GetRemoteSessionParameters());
}

exit:
Expand Down Expand Up @@ -1410,7 +1411,7 @@ CHIP_ERROR CASESession::SendSigma3a()
{
SuccessOrExit(err = helper->ScheduleWork());
mSendSigma3Helper = helper;
mExchangeCtxt->WillSendMessage();
mExchangeCtxt.Value()->WillSendMessage();
mState = State::kSendSigma3Pending;
}
else
Expand Down Expand Up @@ -1537,8 +1538,8 @@ CHIP_ERROR CASESession::SendSigma3c(SendSigma3Data & data, CHIP_ERROR status)
SuccessOrExit(err);

// Call delegate to send the Msg3 to peer
err = mExchangeCtxt->SendMessage(Protocols::SecureChannel::MsgType::CASE_Sigma3, std::move(msg_R3),
SendFlags(SendMessageFlags::kExpectResponse));
err = mExchangeCtxt.Value()->SendMessage(Protocols::SecureChannel::MsgType::CASE_Sigma3, std::move(msg_R3),
SendFlags(SendMessageFlags::kExpectResponse));
SuccessOrExit(err);

ChipLogProgress(SecureChannel, "Sent Sigma3 msg");
Expand Down Expand Up @@ -1704,7 +1705,7 @@ CHIP_ERROR CASESession::HandleSigma3a(System::PacketBufferHandle && msg)

SuccessOrExit(err = helper->ScheduleWork());
mHandleSigma3Helper = helper;
mExchangeCtxt->WillSendMessage();
mExchangeCtxt.Value()->WillSendMessage();
mState = State::kHandleSigma3Pending;
}

Expand Down Expand Up @@ -2031,7 +2032,7 @@ CHIP_ERROR CASESession::ParseSigma1(TLV::ContiguousBufferTLVReader & tlvReader,
if (err == CHIP_NO_ERROR && tlvReader.GetTag() == ContextTag(kInitiatorMRPParamsTag))
{
ReturnErrorOnFailure(DecodeMRPParametersIfPresent(TLV::ContextTag(kInitiatorMRPParamsTag), tlvReader));
mExchangeCtxt->GetSessionHandle()->AsUnauthenticatedSession()->SetRemoteSessionParameters(GetRemoteSessionParameters());
mExchangeCtxt.Value()->GetSessionHandle()->AsUnauthenticatedSession()->SetRemoteSessionParameters(GetRemoteSessionParameters());
err = tlvReader.Next();
}

Expand Down Expand Up @@ -2079,26 +2080,27 @@ CHIP_ERROR CASESession::ParseSigma1(TLV::ContiguousBufferTLVReader & tlvReader,
return CHIP_NO_ERROR;
}

CHIP_ERROR CASESession::ValidateReceivedMessage(ExchangeContext * ec, const PayloadHeader & payloadHeader,
CHIP_ERROR CASESession::ValidateReceivedMessage(chip::Messaging::ExchangeContext * ec, const PayloadHeader & payloadHeader,
const System::PacketBufferHandle & msg)
{
VerifyOrReturnError(ec != nullptr, CHIP_ERROR_INVALID_ARGUMENT);

// mExchangeCtxt can be nullptr if this is the first message (CASE_Sigma1) received by CASESession
// via UnsolicitedMessageHandler. The exchange context is allocated by exchange manager and provided
// to the handler (CASESession object).
if (mExchangeCtxt != nullptr)
if (mExchangeCtxt.HasValue())
{
if (mExchangeCtxt != ec)
if (&mExchangeCtxt.Value().Get() != ec)
{
ReturnErrorOnFailure(CHIP_ERROR_INVALID_ARGUMENT);
}
}
else
{
mExchangeCtxt = ec;
ExchangeHandle ecHandle(*ec);
mExchangeCtxt.SetValue(ecHandle);
}
mExchangeCtxt->UseSuggestedResponseTimeout(kExpectedHighProcessingTime);
mExchangeCtxt.Value()->UseSuggestedResponseTimeout(kExpectedHighProcessingTime);

VerifyOrReturnError(!msg.IsNull(), CHIP_ERROR_INVALID_ARGUMENT);
return CHIP_NO_ERROR;
Expand All @@ -2123,7 +2125,7 @@ CHIP_ERROR CASESession::OnMessageReceived(ExchangeContext * ec, const PayloadHea
//
// Should you need to resume the CASESession, you could theoretically pass along the msg to a callback that gets
// registered when setting mStopHandshakeAtState.
mExchangeCtxt->WillSendMessage();
mExchangeCtxt.Value()->WillSendMessage();
return CHIP_NO_ERROR;
}
#endif // CONFIG_BUILD_FOR_HOST_UNIT_TEST
Expand All @@ -2133,7 +2135,7 @@ CHIP_ERROR CASESession::OnMessageReceived(ExchangeContext * ec, const PayloadHea
msgType == Protocols::SecureChannel::MsgType::CASE_Sigma2Resume ||
msgType == Protocols::SecureChannel::MsgType::CASE_Sigma3)
{
SuccessOrExit(err = mExchangeCtxt->FlushAcks());
SuccessOrExit(err = mExchangeCtxt.Value()->FlushAcks());
}
#endif // CHIP_CONFIG_SLOW_CRYPTO

Expand Down
38 changes: 20 additions & 18 deletions src/protocols/secure_channel/PASESession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,17 +213,18 @@ CHIP_ERROR PASESession::Pair(SessionManager & sessionManager, uint32_t peerSetUp
{
MATTER_TRACE_SCOPE("Pair", "PASESession");
ReturnErrorCodeIf(exchangeCtxt == nullptr, CHIP_ERROR_INVALID_ARGUMENT);
ExchangeHandle ecHandle(*exchangeCtxt);
CHIP_ERROR err = Init(sessionManager, peerSetUpPINCode, delegate);
SuccessOrExit(err);

mRole = CryptoContext::SessionRole::kInitiator;

mExchangeCtxt = exchangeCtxt;
mExchangeCtxt.SetValue(ecHandle);

// When commissioning starts, the peer is assumed to be active.
mExchangeCtxt->GetSessionHandle()->AsUnauthenticatedSession()->MarkActiveRx();
mExchangeCtxt.Value()->GetSessionHandle()->AsUnauthenticatedSession()->MarkActiveRx();

mExchangeCtxt->UseSuggestedResponseTimeout(kExpectedLowProcessingTime);
mExchangeCtxt.Value()->UseSuggestedResponseTimeout(kExpectedLowProcessingTime);

mLocalMRPConfig = mrpLocalConfig.ValueOr(GetDefaultMRPConfig());

Expand All @@ -244,7 +245,7 @@ void PASESession::OnResponseTimeout(ExchangeContext * ec)
{
MATTER_TRACE_SCOPE("OnResponseTimeout", "PASESession");
VerifyOrReturn(ec != nullptr, ChipLogError(SecureChannel, "PASESession::OnResponseTimeout was called by null exchange"));
VerifyOrReturn(mExchangeCtxt == nullptr || mExchangeCtxt == ec,
VerifyOrReturn(!mExchangeCtxt.HasValue() || &mExchangeCtxt.Value().Get() == ec,
ChipLogError(SecureChannel, "PASESession::OnResponseTimeout exchange doesn't match"));
// If we were waiting for something, mNextExpectedMsg had better have a value.
ChipLogError(SecureChannel, "PASESession timed out while waiting for a response from the peer. Expected message type was %u",
Expand Down Expand Up @@ -309,7 +310,7 @@ CHIP_ERROR PASESession::SendPBKDFParamRequest()
ReturnErrorOnFailure(mCommissioningHash.AddData(ByteSpan{ req->Start(), req->DataLength() }));

ReturnErrorOnFailure(
mExchangeCtxt->SendMessage(MsgType::PBKDFParamRequest, std::move(req), SendFlags(SendMessageFlags::kExpectResponse)));
mExchangeCtxt.Value()->SendMessage(MsgType::PBKDFParamRequest, std::move(req), SendFlags(SendMessageFlags::kExpectResponse)));

mNextExpectedMsg.SetValue(MsgType::PBKDFParamResponse);

Expand Down Expand Up @@ -364,7 +365,7 @@ CHIP_ERROR PASESession::HandlePBKDFParamRequest(System::PacketBufferHandle && ms
if (tlvReader.Next() != CHIP_END_OF_TLV)
{
SuccessOrExit(err = DecodeMRPParametersIfPresent(TLV::ContextTag(5), tlvReader));
mExchangeCtxt->GetSessionHandle()->AsUnauthenticatedSession()->SetRemoteSessionParameters(GetRemoteSessionParameters());
mExchangeCtxt.Value()->GetSessionHandle()->AsUnauthenticatedSession()->SetRemoteSessionParameters(GetRemoteSessionParameters());
}

err = SendPBKDFParamResponse(ByteSpan(initiatorRandom), hasPBKDFParameters);
Expand Down Expand Up @@ -429,7 +430,7 @@ CHIP_ERROR PASESession::SendPBKDFParamResponse(ByteSpan initiatorRandom, bool in
ReturnErrorOnFailure(SetupSpake2p());

ReturnErrorOnFailure(
mExchangeCtxt->SendMessage(MsgType::PBKDFParamResponse, std::move(resp), SendFlags(SendMessageFlags::kExpectResponse)));
mExchangeCtxt.Value()->SendMessage(MsgType::PBKDFParamResponse, std::move(resp), SendFlags(SendMessageFlags::kExpectResponse)));
ChipLogDetail(SecureChannel, "Sent PBKDF param response");

mNextExpectedMsg.SetValue(MsgType::PASE_Pake1);
Expand Down Expand Up @@ -483,7 +484,7 @@ CHIP_ERROR PASESession::HandlePBKDFParamResponse(System::PacketBufferHandle && m
if (tlvReader.Next() != CHIP_END_OF_TLV)
{
SuccessOrExit(err = DecodeMRPParametersIfPresent(TLV::ContextTag(5), tlvReader));
mExchangeCtxt->GetSessionHandle()->AsUnauthenticatedSession()->SetRemoteSessionParameters(GetRemoteSessionParameters());
mExchangeCtxt.Value()->GetSessionHandle()->AsUnauthenticatedSession()->SetRemoteSessionParameters(GetRemoteSessionParameters());
}

// TODO - Add a unit test that exercises mHavePBKDFParameters path
Expand All @@ -508,7 +509,7 @@ CHIP_ERROR PASESession::HandlePBKDFParamResponse(System::PacketBufferHandle && m
if (tlvReader.Next() != CHIP_END_OF_TLV)
{
SuccessOrExit(err = DecodeMRPParametersIfPresent(TLV::ContextTag(5), tlvReader));
mExchangeCtxt->GetSessionHandle()->AsUnauthenticatedSession()->SetRemoteSessionParameters(GetRemoteSessionParameters());
mExchangeCtxt.Value()->GetSessionHandle()->AsUnauthenticatedSession()->SetRemoteSessionParameters(GetRemoteSessionParameters());
}
}

Expand Down Expand Up @@ -558,7 +559,7 @@ CHIP_ERROR PASESession::SendMsg1()
ReturnErrorOnFailure(tlvWriter.Finalize(&msg));

ReturnErrorOnFailure(
mExchangeCtxt->SendMessage(MsgType::PASE_Pake1, std::move(msg), SendFlags(SendMessageFlags::kExpectResponse)));
mExchangeCtxt.Value()->SendMessage(MsgType::PASE_Pake1, std::move(msg), SendFlags(SendMessageFlags::kExpectResponse)));
ChipLogDetail(SecureChannel, "Sent spake2p msg1");

mNextExpectedMsg.SetValue(MsgType::PASE_Pake2);
Expand Down Expand Up @@ -620,7 +621,7 @@ CHIP_ERROR PASESession::HandleMsg1_and_SendMsg2(System::PacketBufferHandle && ms
SuccessOrExit(err = tlvWriter.EndContainer(outerContainerType));
SuccessOrExit(err = tlvWriter.Finalize(&msg2));

err = mExchangeCtxt->SendMessage(MsgType::PASE_Pake2, std::move(msg2), SendFlags(SendMessageFlags::kExpectResponse));
err = mExchangeCtxt.Value()->SendMessage(MsgType::PASE_Pake2, std::move(msg2), SendFlags(SendMessageFlags::kExpectResponse));
SuccessOrExit(err);

mNextExpectedMsg.SetValue(MsgType::PASE_Pake3);
Expand Down Expand Up @@ -696,7 +697,7 @@ CHIP_ERROR PASESession::HandleMsg2_and_SendMsg3(System::PacketBufferHandle && ms
SuccessOrExit(err = tlvWriter.EndContainer(outerContainerType));
SuccessOrExit(err = tlvWriter.Finalize(&msg3));

err = mExchangeCtxt->SendMessage(MsgType::PASE_Pake3, std::move(msg3), SendFlags(SendMessageFlags::kExpectResponse));
err = mExchangeCtxt.Value()->SendMessage(MsgType::PASE_Pake3, std::move(msg3), SendFlags(SendMessageFlags::kExpectResponse));
SuccessOrExit(err);

mNextExpectedMsg.SetValue(MsgType::StatusReport);
Expand Down Expand Up @@ -785,25 +786,26 @@ CHIP_ERROR PASESession::ValidateReceivedMessage(ExchangeContext * exchange, cons
// mExchangeCtxt can be nullptr if this is the first message (PBKDFParamRequest) received by PASESession
// via UnsolicitedMessageHandler. The exchange context is allocated by exchange manager and provided
// to the handler (PASESession object).
if (mExchangeCtxt != nullptr)
if (mExchangeCtxt.HasValue())
{
if (mExchangeCtxt != exchange)
if (&mExchangeCtxt.Value().Get() != exchange)
{
ReturnErrorOnFailure(CHIP_ERROR_INVALID_ARGUMENT);
}
}
else
{
mExchangeCtxt = exchange;
ExchangeHandle ecHandle(*exchange);
mExchangeCtxt.SetValue(ecHandle);
}

if (!mExchangeCtxt->GetSessionHandle()->IsUnauthenticatedSession())
if (!mExchangeCtxt.Value()->GetSessionHandle()->IsUnauthenticatedSession())
{
ChipLogError(SecureChannel, "PASESession received PBKDFParamRequest over encrypted session. Ignoring.");
return CHIP_ERROR_INCORRECT_STATE;
}

mExchangeCtxt->UseSuggestedResponseTimeout(kExpectedHighProcessingTime);
mExchangeCtxt.Value()->UseSuggestedResponseTimeout(kExpectedHighProcessingTime);

VerifyOrReturnError(!msg.IsNull(), CHIP_ERROR_INVALID_ARGUMENT);
VerifyOrReturnError((mNextExpectedMsg.HasValue() && payloadHeader.HasMessageType(mNextExpectedMsg.Value())) ||
Expand Down Expand Up @@ -832,7 +834,7 @@ CHIP_ERROR PASESession::OnMessageReceived(ExchangeContext * exchange, const Payl
if (msgType == MsgType::PBKDFParamRequest || msgType == MsgType::PBKDFParamResponse || msgType == MsgType::PASE_Pake1 ||
msgType == MsgType::PASE_Pake2 || msgType == MsgType::PASE_Pake3)
{
SuccessOrExit(err = mExchangeCtxt->FlushAcks());
SuccessOrExit(err = mExchangeCtxt.Value()->FlushAcks());
}
#endif // CHIP_CONFIG_SLOW_CRYPTO

Expand Down
28 changes: 14 additions & 14 deletions src/protocols/secure_channel/PairingSession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ CHIP_ERROR PairingSession::ActivateSecureSession(const Transport::PeerAddress &

void PairingSession::Finish()
{
Transport::PeerAddress address = mExchangeCtxt->GetSessionHandle()->AsUnauthenticatedSession()->GetPeerAddress();
Transport::PeerAddress address = mExchangeCtxt.Value()->GetSessionHandle()->AsUnauthenticatedSession()->GetPeerAddress();

// Discard the exchange so that Clear() doesn't try closing it. The exchange will handle that.
DiscardExchange();
Expand All @@ -79,14 +79,15 @@ void PairingSession::Finish()

void PairingSession::DiscardExchange()
{
if (mExchangeCtxt != nullptr)
if (mExchangeCtxt.HasValue())
{
// Make sure the exchange doesn't try to notify us when it closes,
// since we might be dead by then.
mExchangeCtxt->SetDelegate(nullptr);
mExchangeCtxt.Value()->SetDelegate(nullptr);

// Null out mExchangeCtxt so that Clear() doesn't try closing it. The
// exchange will handle that.
mExchangeCtxt = nullptr;
mExchangeCtxt.ClearValue();
}
}

Expand Down Expand Up @@ -227,19 +228,18 @@ bool PairingSession::IsSessionEstablishmentInProgress()

void PairingSession::Clear()
{
// Clear acts like the destructor if PairingSession, if it is call during
// middle of a pairing, means we should terminate the exchange. For normal
// path, the exchange should already be discarded before calling Clear.
if (mExchangeCtxt != nullptr)
// Clear acts like the destructor of PairingSession, if it is called during
// the middle of pairing, which means we should terminate the exchange. For the
// normal path, the exchange should already be discarded before calling Clear.
if (mExchangeCtxt.HasValue())
{
// The only time we reach this is if we are getting destroyed in the
// middle of our handshake. In that case, there is no point trying to
// do MRP resends of the last message we sent, so abort the exchange
// The only time we reach this is when we are getting destroyed in the
// middle of our handshake. In that case, there is no point in trying to
// do MRP resends of the last message we sent. So, abort the exchange
// instead of just closing it.
mExchangeCtxt->Abort();
mExchangeCtxt = nullptr;
mExchangeCtxt.Value()->Abort();
mExchangeCtxt.ClearValue();
}

mSecureSessionHolder.Release();
mPeerSessionId.ClearValue();
mSessionManager = nullptr;
Expand Down
Loading

0 comments on commit 912adec

Please sign in to comment.