From 8c7d16db10925aab90c25ecb37344ba35b8efe43 Mon Sep 17 00:00:00 2001 From: Pankaj Garg Date: Thu, 16 Dec 2021 13:29:13 -0800 Subject: [PATCH] Update PASE state machine's state after successful send (#13090) --- src/protocols/secure_channel/PASESession.cpp | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/protocols/secure_channel/PASESession.cpp b/src/protocols/secure_channel/PASESession.cpp index 24cc043cb6e9ff..ba62a8aa442c8f 100644 --- a/src/protocols/secure_channel/PASESession.cpp +++ b/src/protocols/secure_channel/PASESession.cpp @@ -392,11 +392,11 @@ CHIP_ERROR PASESession::SendPBKDFParamRequest() // Update commissioning hash with the pbkdf2 param request that's being sent. ReturnErrorOnFailure(mCommissioningHash.AddData(ByteSpan{ req->Start(), req->DataLength() })); - mNextExpectedMsg = MsgType::PBKDFParamResponse; - ReturnErrorOnFailure( mExchangeCtxt->SendMessage(MsgType::PBKDFParamRequest, std::move(req), SendFlags(SendMessageFlags::kExpectResponse))); + mNextExpectedMsg = MsgType::PBKDFParamResponse; + ChipLogDetail(SecureChannel, "Sent PBKDF param request"); return CHIP_NO_ERROR; @@ -512,12 +512,12 @@ CHIP_ERROR PASESession::SendPBKDFParamResponse(ByteSpan initiatorRandom, bool in size_t sizeof_point = sizeof(mPoint); ReturnErrorOnFailure(mSpake2p.ComputeL(mPoint, &sizeof_point, mPASEVerifier.mL, kSpake2p_WS_Length)); - mNextExpectedMsg = MsgType::PASE_Pake1; - ReturnErrorOnFailure( mExchangeCtxt->SendMessage(MsgType::PBKDFParamResponse, std::move(resp), SendFlags(SendMessageFlags::kExpectResponse))); ChipLogDetail(SecureChannel, "Sent PBKDF param response"); + mNextExpectedMsg = MsgType::PASE_Pake1; + return CHIP_NO_ERROR; } @@ -636,12 +636,12 @@ CHIP_ERROR PASESession::SendMsg1() ReturnErrorOnFailure(tlvWriter.EndContainer(outerContainerType)); ReturnErrorOnFailure(tlvWriter.Finalize(&msg)); - mNextExpectedMsg = MsgType::PASE_Pake2; - ReturnErrorOnFailure( mExchangeCtxt->SendMessage(MsgType::PASE_Pake1, std::move(msg), SendFlags(SendMessageFlags::kExpectResponse))); ChipLogDetail(SecureChannel, "Sent spake2p msg1"); + mNextExpectedMsg = MsgType::PASE_Pake2; + return CHIP_NO_ERROR; } @@ -697,10 +697,10 @@ CHIP_ERROR PASESession::HandleMsg1_and_SendMsg2(System::PacketBufferHandle && ms SuccessOrExit(err = tlvWriter.EndContainer(outerContainerType)); SuccessOrExit(err = tlvWriter.Finalize(&msg2)); - mNextExpectedMsg = MsgType::PASE_Pake3; - err = mExchangeCtxt->SendMessage(MsgType::PASE_Pake2, std::move(msg2), SendFlags(SendMessageFlags::kExpectResponse)); SuccessOrExit(err); + + mNextExpectedMsg = MsgType::PASE_Pake3; } ChipLogDetail(SecureChannel, "Sent spake2p msg2"); @@ -772,10 +772,10 @@ CHIP_ERROR PASESession::HandleMsg2_and_SendMsg3(System::PacketBufferHandle && ms SuccessOrExit(err = tlvWriter.EndContainer(outerContainerType)); SuccessOrExit(err = tlvWriter.Finalize(&msg3)); - mNextExpectedMsg = MsgType::StatusReport; - err = mExchangeCtxt->SendMessage(MsgType::PASE_Pake3, std::move(msg3), SendFlags(SendMessageFlags::kExpectResponse)); SuccessOrExit(err); + + mNextExpectedMsg = MsgType::StatusReport; } ChipLogDetail(SecureChannel, "Sent spake2p msg3");