From 1173911cafad82bb8c7c228e7d38232af8c25bc3 Mon Sep 17 00:00:00 2001 From: Pankaj Garg Date: Mon, 17 May 2021 10:39:27 -0700 Subject: [PATCH] Cleanup PASE code, error handling and error logging (#6852) --- src/lib/support/logging/CHIPLogging.cpp | 2 + src/lib/support/logging/Constants.h | 2 + src/protocols/secure_channel/PASESession.cpp | 171 ++++++++++--------- src/protocols/secure_channel/PASESession.h | 5 +- 4 files changed, 96 insertions(+), 84 deletions(-) diff --git a/src/lib/support/logging/CHIPLogging.cpp b/src/lib/support/logging/CHIPLogging.cpp index d70009e210b10f..1a51e17909e456 100644 --- a/src/lib/support/logging/CHIPLogging.cpp +++ b/src/lib/support/logging/CHIPLogging.cpp @@ -91,6 +91,8 @@ static const char ModuleNames[] = "-\0\0" // None "SPL" // SetupPayload "SVR" // AppServer "DIS" // Discovery + "PAS" // PASE + "CAS" // CASE ; #define ModuleNamesCount ((sizeof(ModuleNames) - 1) / chip::Logging::kMaxModuleNameLen) diff --git a/src/lib/support/logging/Constants.h b/src/lib/support/logging/Constants.h index c605e2fd93214b..5b7f452b97ca67 100644 --- a/src/lib/support/logging/Constants.h +++ b/src/lib/support/logging/Constants.h @@ -54,6 +54,8 @@ enum LogModule kLogModule_SetupPayload, kLogModule_AppServer, kLogModule_Discovery, + kLogModule_PASE, + kLogModule_CASE, kLogModule_Max }; diff --git a/src/protocols/secure_channel/PASESession.cpp b/src/protocols/secure_channel/PASESession.cpp index 65706c3bd555b9..dc532b67bc374c 100644 --- a/src/protocols/secure_channel/PASESession.cpp +++ b/src/protocols/secure_channel/PASESession.cpp @@ -40,6 +40,7 @@ #include #include #include +#include #include #include @@ -184,7 +185,7 @@ CHIP_ERROR PASESession::Init(uint16_t myKeyId, uint32_t setupCode, SessionEstabl mDelegate = delegate; - ChipLogDetail(Ble, "Assigned local session key ID %d", myKeyId); + ChipLogDetail(PASE, "Assigned local session key ID %d", myKeyId); mConnectionState.SetLocalKeyID(myKeyId); mSetupPINCode = setupCode; mComputeVerifier = true; @@ -245,12 +246,10 @@ CHIP_ERROR PASESession::SetupSpake2p(uint32_t pbkdf2IterCount, const uint8_t * s CHIP_ERROR PASESession::WaitForPairing(uint32_t mySetUpPINCode, uint32_t pbkdf2IterCount, const uint8_t * salt, size_t saltLen, uint16_t myKeyId, SessionEstablishmentDelegate * delegate) { - CHIP_ERROR err = CHIP_NO_ERROR; - - VerifyOrExit(salt != nullptr, err = CHIP_ERROR_INVALID_ARGUMENT); - VerifyOrExit(saltLen > 0, err = CHIP_ERROR_INVALID_ARGUMENT); + ReturnErrorCodeIf(salt == nullptr, CHIP_ERROR_INVALID_ARGUMENT); + ReturnErrorCodeIf(saltLen == 0, CHIP_ERROR_INVALID_ARGUMENT); - err = Init(myKeyId, mySetUpPINCode, delegate); + CHIP_ERROR err = Init(myKeyId, mySetUpPINCode, delegate); SuccessOrExit(err); VerifyOrExit(CanCastTo(saltLen), err = CHIP_ERROR_INVALID_ARGUMENT); @@ -272,7 +271,7 @@ CHIP_ERROR PASESession::WaitForPairing(uint32_t mySetUpPINCode, uint32_t pbkdf2I mNextExpectedMsg = Protocols::SecureChannel::MsgType::PBKDFParamRequest; mPairingComplete = false; - ChipLogDetail(Ble, "Waiting for PBKDF param request"); + ChipLogDetail(PASE, "Waiting for PBKDF param request"); exit: if (err != CHIP_NO_ERROR) @@ -284,19 +283,14 @@ CHIP_ERROR PASESession::WaitForPairing(uint32_t mySetUpPINCode, uint32_t pbkdf2I CHIP_ERROR PASESession::WaitForPairing(const PASEVerifier & verifier, uint16_t myKeyId, SessionEstablishmentDelegate * delegate) { - CHIP_ERROR err = WaitForPairing(0, kSpake2p_Iteration_Count, reinterpret_cast(kSpake2pKeyExchangeSalt), - strlen(kSpake2pKeyExchangeSalt), myKeyId, delegate); - SuccessOrExit(err); + ReturnErrorOnFailure(WaitForPairing(0, kSpake2p_Iteration_Count, + reinterpret_cast(kSpake2pKeyExchangeSalt), + strlen(kSpake2pKeyExchangeSalt), myKeyId, delegate)); memmove(&mPASEVerifier, verifier, sizeof(verifier)); mComputeVerifier = false; -exit: - if (err != CHIP_NO_ERROR) - { - Clear(); - } - return err; + return CHIP_NO_ERROR; } CHIP_ERROR PASESession::Pair(const Transport::PeerAddress peerAddress, uint32_t peerSetUpPINCode, uint16_t myKeyId, @@ -324,12 +318,13 @@ CHIP_ERROR PASESession::Pair(const Transport::PeerAddress peerAddress, uint32_t void PASESession::OnResponseTimeout(ExchangeContext * ec) { - VerifyOrReturn(ec != nullptr, ChipLogError(Ble, "PASESession::OnResponseTimeout was called by null exchange")); + VerifyOrReturn(ec != nullptr, ChipLogError(PASE, "PASESession::OnResponseTimeout was called by null exchange")); VerifyOrReturn(mExchangeCtxt == nullptr || mExchangeCtxt == ec, - ChipLogError(Ble, "PASESession::OnResponseTimeout exchange doesn't match")); - ChipLogError(Ble, "PASESession timed out while waiting for a response from the peer. Expected message type was %d", + ChipLogError(PASE, "PASESession::OnResponseTimeout exchange doesn't match")); + ChipLogError(PASE, "PASESession timed out while waiting for a response from the peer. Expected message type was %d", mNextExpectedMsg); mDelegate->OnSessionEstablishmentError(CHIP_ERROR_TIMEOUT); + Clear(); } CHIP_ERROR PASESession::DeriveSecureSession(SecureSession & session, SecureSession::SessionRole role) @@ -341,35 +336,24 @@ CHIP_ERROR PASESession::DeriveSecureSession(SecureSession & session, SecureSessi CHIP_ERROR PASESession::SendPBKDFParamRequest() { - CHIP_ERROR err = CHIP_NO_ERROR; - System::PacketBufferHandle req = System::PacketBufferHandle::New(kPBKDFParamRandomNumberSize); - VerifyOrExit(!req.IsNull(), err = CHIP_SYSTEM_ERROR_NO_MEMORY); + VerifyOrReturnError(!req.IsNull(), CHIP_SYSTEM_ERROR_NO_MEMORY); - err = DRBG_get_bytes(req->Start(), kPBKDFParamRandomNumberSize); - SuccessOrExit(err); + ReturnErrorOnFailure(DRBG_get_bytes(req->Start(), kPBKDFParamRandomNumberSize)); req->SetDataLength(kPBKDFParamRandomNumberSize); // Update commissioning hash with the pbkdf2 param request that's being sent. - err = mCommissioningHash.AddData(req->Start(), req->DataLength()); - SuccessOrExit(err); + ReturnErrorOnFailure(mCommissioningHash.AddData(req->Start(), req->DataLength())); mNextExpectedMsg = Protocols::SecureChannel::MsgType::PBKDFParamResponse; - err = mExchangeCtxt->SendMessage(Protocols::SecureChannel::MsgType::PBKDFParamRequest, std::move(req), - SendFlags(SendMessageFlags::kExpectResponse)); - SuccessOrExit(err); - - ChipLogDetail(Ble, "Sent PBKDF param request"); + ReturnErrorOnFailure(mExchangeCtxt->SendMessage(Protocols::SecureChannel::MsgType::PBKDFParamRequest, std::move(req), + SendFlags(SendMessageFlags::kExpectResponse))); -exit: + ChipLogDetail(PASE, "Sent PBKDF param request"); - if (err != CHIP_NO_ERROR) - { - Clear(); - } - return err; + return CHIP_NO_ERROR; } CHIP_ERROR PASESession::HandlePBKDFParamRequest(const System::PacketBufferHandle & msg) @@ -383,7 +367,7 @@ CHIP_ERROR PASESession::HandlePBKDFParamRequest(const System::PacketBufferHandle VerifyOrExit(req != nullptr, err = CHIP_ERROR_MESSAGE_INCOMPLETE); VerifyOrExit(reqlen == kPBKDFParamRandomNumberSize, err = CHIP_ERROR_INVALID_MESSAGE_LENGTH); - ChipLogDetail(Ble, "Received PBKDF param request"); + ChipLogDetail(PASE, "Received PBKDF param request"); // Update commissioning hash with the received pbkdf2 param request err = mCommissioningHash.AddData(req, reqlen); @@ -439,7 +423,7 @@ CHIP_ERROR PASESession::SendPBKDFParamResponse() ReturnErrorOnFailure(mExchangeCtxt->SendMessage(Protocols::SecureChannel::MsgType::PBKDFParamResponse, std::move(resp), SendFlags(SendMessageFlags::kExpectResponse))); - ChipLogDetail(Ble, "Sent PBKDF param response"); + ChipLogDetail(PASE, "Sent PBKDF param response"); return CHIP_NO_ERROR; } @@ -457,7 +441,7 @@ CHIP_ERROR PASESession::HandlePBKDFParamResponse(const System::PacketBufferHandl static_assert(CHAR_BIT == 8, "Assuming that sizeof returns octets"); size_t fixed_resplen = kPBKDFParamRandomNumberSize + sizeof(uint64_t) + sizeof(uint32_t); - ChipLogDetail(Ble, "Received PBKDF param response"); + ChipLogDetail(PASE, "Received PBKDF param response"); VerifyOrExit(resp != nullptr, err = CHIP_ERROR_MESSAGE_INCOMPLETE); VerifyOrExit(resplen >= fixed_resplen, err = CHIP_ERROR_INVALID_MESSAGE_LENGTH); @@ -514,7 +498,7 @@ CHIP_ERROR PASESession::SendMsg1() // Call delegate to send the Msg1 to peer ReturnErrorOnFailure(mExchangeCtxt->SendMessage(Protocols::SecureChannel::MsgType::PASE_Spake2p1, bbuf.Finalize(), SendFlags(SendMessageFlags::kExpectResponse))); - ChipLogDetail(Ble, "Sent spake2p msg1"); + ChipLogDetail(PASE, "Sent spake2p msg1"); return CHIP_NO_ERROR; } @@ -536,7 +520,7 @@ CHIP_ERROR PASESession::HandleMsg1_and_SendMsg2(const System::PacketBufferHandle uint16_t encryptionKeyId = 0; - ChipLogDetail(Ble, "Received spake2p msg1"); + ChipLogDetail(PASE, "Received spake2p msg1"); VerifyOrExit(buf != nullptr, err = CHIP_ERROR_MESSAGE_INCOMPLETE); VerifyOrExit(buf_len == sizeof(encryptionKeyId) + kMAX_Point_Length, err = CHIP_ERROR_INVALID_MESSAGE_LENGTH); @@ -551,7 +535,7 @@ CHIP_ERROR PASESession::HandleMsg1_and_SendMsg2(const System::PacketBufferHandle err = mSpake2p.ComputeRoundOne(msg->Start(), msg->DataLength(), Y, &Y_len); SuccessOrExit(err); - ChipLogDetail(Ble, "Peer assigned session key ID %d", encryptionKeyId); + ChipLogDetail(PASE, "Peer assigned session key ID %d", encryptionKeyId); mConnectionState.SetPeerKeyID(encryptionKeyId); err = mSpake2p.ComputeRoundTwo(msg->Start(), msg->DataLength(), verifier, &verifier_len); @@ -578,7 +562,7 @@ CHIP_ERROR PASESession::HandleMsg1_and_SendMsg2(const System::PacketBufferHandle SuccessOrExit(err); } - ChipLogDetail(Ble, "Sent spake2p msg2"); + ChipLogDetail(PASE, "Sent spake2p msg2"); exit: @@ -606,7 +590,7 @@ CHIP_ERROR PASESession::HandleMsg2_and_SendMsg3(const System::PacketBufferHandle uint16_t encryptionKeyId = 0; - ChipLogDetail(Ble, "Received spake2p msg2"); + ChipLogDetail(PASE, "Received spake2p msg2"); VerifyOrExit(buf != nullptr, err = CHIP_ERROR_MESSAGE_INCOMPLETE); VerifyOrExit(buf_len == sizeof(encryptionKeyId) + kMAX_Point_Length + kMAX_Hash_Length, @@ -617,7 +601,7 @@ CHIP_ERROR PASESession::HandleMsg2_and_SendMsg3(const System::PacketBufferHandle buf = msg->Start(); buf_len = msg->DataLength(); - ChipLogDetail(Ble, "Peer assigned session key ID %d", encryptionKeyId); + ChipLogDetail(PASE, "Peer assigned session key ID %d", encryptionKeyId); mConnectionState.SetPeerKeyID(encryptionKeyId); err = mSpake2p.ComputeRoundTwo(buf, kMAX_Point_Length, verifier, &verifier_len_raw); @@ -637,7 +621,7 @@ CHIP_ERROR PASESession::HandleMsg2_and_SendMsg3(const System::PacketBufferHandle SuccessOrExit(err); } - ChipLogDetail(Ble, "Sent spake2p msg3"); + ChipLogDetail(PASE, "Sent spake2p msg3"); { const uint8_t * hash = &buf[kMAX_Point_Length]; @@ -654,6 +638,13 @@ CHIP_ERROR PASESession::HandleMsg2_and_SendMsg3(const System::PacketBufferHandle mPairingComplete = true; + // Close the exchange, as no additional messages are expected from the peer + if (mExchangeCtxt != nullptr) + { + mExchangeCtxt->Close(); + mExchangeCtxt = nullptr; + } + // Call delegate to indicate pairing completion mDelegate->OnSessionEstablished(); @@ -672,7 +663,7 @@ CHIP_ERROR PASESession::HandleMsg3(const System::PacketBufferHandle & msg) const uint8_t * hash = msg->Start(); Spake2pErrorType spake2pErr = Spake2pErrorType::kUnexpected; - ChipLogDetail(Ble, "Received spake2p msg3"); + ChipLogDetail(PASE, "Received spake2p msg3"); // We will set NextExpectedMsg to PASE_Spake2pError in all cases // However, when we are using IP rendezvous, we might set it to PASE_Spake2p1. @@ -693,6 +684,13 @@ CHIP_ERROR PASESession::HandleMsg3(const System::PacketBufferHandle & msg) mPairingComplete = true; + // Close the exchange, as no additional messages are expected from the peer + if (mExchangeCtxt != nullptr) + { + mExchangeCtxt->Close(); + mExchangeCtxt = nullptr; + } + // Call delegate to indicate pairing completion mDelegate->OnSessionEstablished(); @@ -707,64 +705,69 @@ CHIP_ERROR PASESession::HandleMsg3(const System::PacketBufferHandle & msg) void PASESession::SendErrorMsg(Spake2pErrorType errorCode) { - CHIP_ERROR err = CHIP_NO_ERROR; - System::PacketBufferHandle msg; uint16_t msglen = sizeof(Spake2pErrorMsg); Spake2pErrorMsg * pMsg = nullptr; msg = System::PacketBufferHandle::New(msglen); - VerifyOrExit(!msg.IsNull(), err = CHIP_SYSTEM_ERROR_NO_MEMORY); + VerifyOrReturn(!msg.IsNull(), ChipLogError(PASE, "Failed to allocate error message")); pMsg = reinterpret_cast(msg->Start()); pMsg->error = errorCode; msg->SetDataLength(msglen); - err = mExchangeCtxt->SendMessage(Protocols::SecureChannel::MsgType::PASE_Spake2pError, std::move(msg)); - SuccessOrExit(err); - -exit: - Clear(); + VerifyOrReturn(mExchangeCtxt->SendMessage(Protocols::SecureChannel::MsgType::PASE_Spake2pError, std::move(msg)), + ChipLogError(PASE, "Failed to send error message")); } -void PASESession::HandleErrorMsg(const System::PacketBufferHandle & msg) +CHIP_ERROR PASESession::HandleErrorMsg(const System::PacketBufferHandle & msg) { - // Request message processing - const uint8_t * buf = msg->Start(); - size_t buflen = msg->DataLength(); - Spake2pErrorMsg * pMsg = nullptr; - - VerifyOrExit(buf != nullptr, ChipLogError(Ble, "Null error msg received during pairing")); - VerifyOrExit(buflen == sizeof(Spake2pErrorMsg), ChipLogError(Ble, "Error msg with incorrect length received during pairing")); + ReturnErrorCodeIf(msg->Start() == nullptr || msg->DataLength() != sizeof(Spake2pErrorMsg), CHIP_ERROR_MESSAGE_INCOMPLETE); - pMsg = reinterpret_cast(msg->Start()); - ChipLogError(Ble, "Received error (%d) during pairing process", pMsg->error); + Spake2pErrorMsg * pMsg = reinterpret_cast(msg->Start()); + ChipLogError(PASE, "Received error during pairing process. %s", ErrorStr(pMsg->error)); - mDelegate->OnSessionEstablishmentError(pMsg->error); - -exit: - Clear(); + return pMsg->error; } -void PASESession::OnMessageReceived(ExchangeContext * ec, const PacketHeader & packetHeader, const PayloadHeader & payloadHeader, - System::PacketBufferHandle msg) +CHIP_ERROR PASESession::ValidateReceivedMessage(ExchangeContext * exchange, const PacketHeader & packetHeader, + const PayloadHeader & payloadHeader, System::PacketBufferHandle & msg) { - CHIP_ERROR err = CHIP_NO_ERROR; + VerifyOrReturnError(exchange != nullptr, CHIP_ERROR_INVALID_ARGUMENT); - VerifyOrExit(ec != nullptr, err = CHIP_ERROR_INVALID_ARGUMENT); - VerifyOrExit(mExchangeCtxt == nullptr || mExchangeCtxt == ec, err = CHIP_ERROR_INVALID_ARGUMENT); - VerifyOrExit(!msg.IsNull(), err = CHIP_ERROR_INVALID_ARGUMENT); - VerifyOrExit(payloadHeader.HasMessageType(mNextExpectedMsg) || - payloadHeader.HasMessageType(Protocols::SecureChannel::MsgType::PASE_Spake2pError), - err = CHIP_ERROR_INVALID_MESSAGE_TYPE); - - if (mExchangeCtxt == nullptr) + // mExchangeCtxt can be nullptr if this is the first message (PBKDFParamRequest) received by PASESession + // via UnsolicitedMessageHandler. The exchange context is allocated by exchange manager and provided + // to the handler (PASESession object). + if (mExchangeCtxt != nullptr) + { + if (mExchangeCtxt != exchange) + { + // Close the incoming exchange explicitly, as the cleanup code only closes mExchangeCtxt + exchange->Close(); + ReturnErrorOnFailure(CHIP_ERROR_INVALID_ARGUMENT); + } + } + else { - mExchangeCtxt = ec; + mExchangeCtxt = exchange; mExchangeCtxt->SetResponseTimeout(kSpake2p_Response_Timeout); } + VerifyOrReturnError(!msg.IsNull(), CHIP_ERROR_INVALID_ARGUMENT); + VerifyOrReturnError(payloadHeader.HasMessageType(mNextExpectedMsg) || + payloadHeader.HasMessageType(Protocols::SecureChannel::MsgType::PASE_Spake2pError), + CHIP_ERROR_INVALID_MESSAGE_TYPE); + + return CHIP_NO_ERROR; +} + +void PASESession::OnMessageReceived(ExchangeContext * exchange, const PacketHeader & packetHeader, + const PayloadHeader & payloadHeader, System::PacketBufferHandle msg) +{ + CHIP_ERROR err = ValidateReceivedMessage(exchange, packetHeader, payloadHeader, msg); + SuccessOrExit(err); + mConnectionState.SetPeerAddress(mMessageDispatch.GetPeerAddress()); switch (static_cast(payloadHeader.GetMessageType())) @@ -790,7 +793,7 @@ void PASESession::OnMessageReceived(ExchangeContext * ec, const PacketHeader & p break; case Protocols::SecureChannel::MsgType::PASE_Spake2pError: - HandleErrorMsg(msg); + err = HandleErrorMsg(msg); break; default: @@ -803,6 +806,8 @@ void PASESession::OnMessageReceived(ExchangeContext * ec, const PacketHeader & p // Call delegate to indicate pairing failure if (err != CHIP_NO_ERROR) { + Clear(); + ChipLogError(PASE, "Failed during PASE session setup. %s", ErrorStr(err)); mDelegate->OnSessionEstablishmentError(err); } } diff --git a/src/protocols/secure_channel/PASESession.h b/src/protocols/secure_channel/PASESession.h index f65df90186a99c..b939a8272785f5 100644 --- a/src/protocols/secure_channel/PASESession.h +++ b/src/protocols/secure_channel/PASESession.h @@ -243,6 +243,9 @@ class DLL_EXPORT PASESession : public Messaging::ExchangeDelegateBase, public Pa CHIP_ERROR Init(uint16_t myKeyId, uint32_t setupCode, SessionEstablishmentDelegate * delegate); + CHIP_ERROR ValidateReceivedMessage(Messaging::ExchangeContext * exchange, const PacketHeader & packetHeader, + const PayloadHeader & payloadHeader, System::PacketBufferHandle & msg); + static CHIP_ERROR ComputePASEVerifier(uint32_t mySetUpPINCode, uint32_t pbkdf2IterCount, const uint8_t * salt, size_t saltLen, PASEVerifier & verifier); @@ -261,7 +264,7 @@ class DLL_EXPORT PASESession : public Messaging::ExchangeDelegateBase, public Pa CHIP_ERROR HandleMsg3(const System::PacketBufferHandle & msg); void SendErrorMsg(Spake2pErrorType errorCode); - void HandleErrorMsg(const System::PacketBufferHandle & msg); + CHIP_ERROR HandleErrorMsg(const System::PacketBufferHandle & msg); SessionEstablishmentDelegate * mDelegate = nullptr;