diff --git a/src/protocols/secure_channel/CASESession.cpp b/src/protocols/secure_channel/CASESession.cpp index fa0ff9b0333364..61e9cb88f7ace4 100644 --- a/src/protocols/secure_channel/CASESession.cpp +++ b/src/protocols/secure_channel/CASESession.cpp @@ -103,6 +103,11 @@ void CASESession::Clear() mTrustedRootId.mId = nullptr; } + CloseExchange(); +} + +void CASESession::CloseExchange() +{ if (mExchangeCtxt != nullptr) { mExchangeCtxt->Close(); @@ -220,7 +225,7 @@ CASESession::ListenForSessionEstablishment(OperationalCredentialSet * operationa mNextExpectedMsg = Protocols::SecureChannel::MsgType::CASE_SigmaR1; mPairingComplete = false; - ChipLogDetail(Inet, "Waiting for SigmaR1 msg"); + ChipLogDetail(SecureChannel, "Waiting for SigmaR1 msg"); return CHIP_NO_ERROR; } @@ -231,6 +236,7 @@ CHIP_ERROR CASESession::EstablishSession(const Transport::PeerAddress peerAddres { CHIP_ERROR err = CHIP_NO_ERROR; + // Return early on error here, as we have not initalized any state yet ReturnErrorCodeIf(exchangeCtxt == nullptr, CHIP_ERROR_INVALID_ARGUMENT); err = Init(operationalCredentialSet, myKeyId, delegate); @@ -238,6 +244,9 @@ CHIP_ERROR CASESession::EstablishSession(const Transport::PeerAddress peerAddres // We are setting the exchange context specifically before checking for error. // This is to make sure the exchange will get closed if Init() returned an error. mExchangeCtxt = exchangeCtxt; + + // From here onwards, let's go to exit on error, as some state might have already + // been initialized SuccessOrExit(err); mExchangeCtxt->SetResponseTimeout(kSigma_Response_Timeout); @@ -257,9 +266,9 @@ CHIP_ERROR CASESession::EstablishSession(const Transport::PeerAddress peerAddres void CASESession::OnResponseTimeout(ExchangeContext * ec) { - VerifyOrReturn(ec != nullptr, ChipLogError(Inet, "CASESession::OnResponseTimeout was called by null exchange")); - VerifyOrReturn(mExchangeCtxt == ec, ChipLogError(Inet, "CASESession::OnResponseTimeout exchange doesn't match")); - ChipLogError(Inet, "CASESession timed out while waiting for a response from the peer. Expected message type was %d", + VerifyOrReturn(ec != nullptr, ChipLogError(SecureChannel, "CASESession::OnResponseTimeout was called by null exchange")); + VerifyOrReturn(mExchangeCtxt == ec, ChipLogError(SecureChannel, "CASESession::OnResponseTimeout exchange doesn't match")); + ChipLogError(SecureChannel, "CASESession timed out while waiting for a response from the peer. Expected message type was %d", mNextExpectedMsg); mDelegate->OnSessionEstablishmentError(CHIP_ERROR_TIMEOUT); Clear(); @@ -347,7 +356,7 @@ CHIP_ERROR CASESession::SendSigmaR1() ReturnErrorOnFailure(mExchangeCtxt->SendMessage(Protocols::SecureChannel::MsgType::CASE_SigmaR1, std::move(msg_R1), SendFlags(SendMessageFlags::kExpectResponse))); - ChipLogDetail(Inet, "Sent SigmaR1 msg"); + ChipLogDetail(SecureChannel, "Sent SigmaR1 msg"); return CHIP_NO_ERROR; } @@ -377,7 +386,7 @@ CHIP_ERROR CASESession::HandleSigmaR1(const System::PacketBufferHandle & msg) VerifyOrExit(buf != nullptr, err = CHIP_ERROR_MESSAGE_INCOMPLETE); VerifyOrExit(buflen >= fixed_buflen, err = CHIP_ERROR_INVALID_MESSAGE_LENGTH); - ChipLogDetail(Inet, "Received SigmaR1 msg"); + ChipLogDetail(SecureChannel, "Received SigmaR1 msg"); err = mCommissioningHash.AddData(msg->Start(), msg->DataLength()); SuccessOrExit(err); @@ -394,7 +403,7 @@ CHIP_ERROR CASESession::HandleSigmaR1(const System::PacketBufferHandle & msg) bbuf.Put(buf, kP256_PublicKey_Length); VerifyOrExit(bbuf.Fit(), err = CHIP_ERROR_NO_MEMORY); - ChipLogDetail(Inet, "Peer assigned session key ID %d", encryptionKeyId); + ChipLogDetail(SecureChannel, "Peer assigned session key ID %d", encryptionKeyId); mConnectionState.SetPeerKeyID(encryptionKeyId); exit: @@ -553,7 +562,7 @@ CHIP_ERROR CASESession::SendSigmaR2() SendFlags(SendMessageFlags::kExpectResponse)); SuccessOrExit(err); - ChipLogDetail(Inet, "Sent SigmaR2 msg"); + ChipLogDetail(SecureChannel, "Sent SigmaR2 msg"); exit: @@ -605,7 +614,7 @@ CHIP_ERROR CASESession::HandleSigmaR2(const System::PacketBufferHandle & msg) VerifyOrExit(buf != nullptr, err = CHIP_ERROR_MESSAGE_INCOMPLETE); - ChipLogDetail(Inet, "Received SigmaR2 msg"); + ChipLogDetail(SecureChannel, "Received SigmaR2 msg"); // Step 1 // skip random part @@ -613,7 +622,7 @@ CHIP_ERROR CASESession::HandleSigmaR2(const System::PacketBufferHandle & msg) encryptionKeyId = chip::Encoding::LittleEndian::Read16(buf); - ChipLogDetail(Inet, "Peer assigned session key ID %d", encryptionKeyId); + ChipLogDetail(SecureChannel, "Peer assigned session key ID %d", encryptionKeyId); mConnectionState.SetPeerKeyID(encryptionKeyId); err = FindValidTrustedRoot(&buf, 1); @@ -725,7 +734,7 @@ CHIP_ERROR CASESession::SendSigmaR3() // Step 1 saltlen = kIPKSize + kSHA256_Hash_Length; - ChipLogDetail(Inet, "Sending SigmaR3"); + ChipLogDetail(SecureChannel, "Sending SigmaR3"); msg_salt = System::PacketBufferHandle::New(saltlen); VerifyOrExit(!msg_salt.IsNull(), err = CHIP_SYSTEM_ERROR_NO_MEMORY); msg_salt->SetDataLength(saltlen); @@ -809,13 +818,16 @@ CHIP_ERROR CASESession::SendSigmaR3() err = mExchangeCtxt->SendMessage(Protocols::SecureChannel::MsgType::CASE_SigmaR3, std::move(msg_R3)); SuccessOrExit(err); - ChipLogDetail(Inet, "Sent SigmaR3 msg"); + ChipLogDetail(SecureChannel, "Sent SigmaR3 msg"); err = mCommissioningHash.Finish(mMessageDigest); SuccessOrExit(err); mPairingComplete = true; + // Close the exchange, as no additional messages are expected from the peer + CloseExchange(); + // Call delegate to indicate pairing completion mDelegate->OnSessionEstablished(); @@ -854,7 +866,7 @@ CHIP_ERROR CASESession::HandleSigmaR3(const System::PacketBufferHandle & msg) HKDF_sha_crypto mHKDF; - ChipLogDetail(Inet, "Received SigmaR3 msg"); + ChipLogDetail(SecureChannel, "Received SigmaR3 msg"); mNextExpectedMsg = Protocols::SecureChannel::MsgType::CASE_SigmaErr; @@ -908,6 +920,9 @@ CHIP_ERROR CASESession::HandleSigmaR3(const System::PacketBufferHandle & msg) mPairingComplete = true; + // Close the exchange, as no additional messages are expected from the peer + CloseExchange(); + // Call delegate to indicate pairing completion mDelegate->OnSessionEstablished(); @@ -925,25 +940,20 @@ CHIP_ERROR CASESession::HandleSigmaR3(const System::PacketBufferHandle & msg) void CASESession::SendErrorMsg(SigmaErrorType errorCode) { - CHIP_ERROR err = CHIP_NO_ERROR; - System::PacketBufferHandle msg; uint16_t msglen = sizeof(SigmaErrorMsg); SigmaErrorMsg * pMsg = nullptr; msg = System::PacketBufferHandle::New(msglen); - VerifyOrExit(!msg.IsNull(), err = CHIP_SYSTEM_ERROR_NO_MEMORY); + VerifyOrReturn(!msg.IsNull(), ChipLogError(SecureChannel, "Failed to allocate error message")); pMsg = reinterpret_cast(msg->Start()); pMsg->error = errorCode; msg->SetDataLength(msglen); - err = mExchangeCtxt->SendMessage(Protocols::SecureChannel::MsgType::CASE_SigmaErr, std::move(msg)); - SuccessOrExit(err); - -exit: - Clear(); + VerifyOrReturn(mExchangeCtxt->SendMessage(Protocols::SecureChannel::MsgType::CASE_SigmaErr, std::move(msg)) != CHIP_NO_ERROR, + ChipLogError(SecureChannel, "Failed to send error message")); } CHIP_ERROR CASESession::FindValidTrustedRoot(const uint8_t ** msgIterator, uint32_t nTrustedRoots) @@ -1090,23 +1100,39 @@ CHIP_ERROR CASESession::SetEffectiveTime(void) return ASN1ToChipEpochTime(effectiveTime, mValidContext.mEffectiveTime); } -void CASESession::HandleErrorMsg(const System::PacketBufferHandle & msg) +CHIP_ERROR CASESession::HandleErrorMsg(const System::PacketBufferHandle & msg) { - // Error message processing - const uint8_t * buf = msg->Start(); - size_t buflen = msg->DataLength(); - SigmaErrorMsg * pMsg = nullptr; + ReturnErrorCodeIf(msg->Start() == nullptr || msg->DataLength() != sizeof(SigmaErrorMsg), CHIP_ERROR_MESSAGE_INCOMPLETE); - VerifyOrExit(buf != nullptr, ChipLogError(Inet, "Null error msg received during pairing")); static_assert(sizeof(SigmaErrorMsg) == sizeof(uint8_t), "Assuming size of SigmaErrorMsg message is 1 octet, so that endian-ness conversion is not needed"); - VerifyOrExit(buflen == sizeof(SigmaErrorMsg), ChipLogError(Inet, "Error msg with incorrect length received during pairing")); - pMsg = reinterpret_cast(msg->Start()); - ChipLogError(Inet, "Received error (%d) during CASE pairing process", pMsg->error); + SigmaErrorMsg * pMsg = reinterpret_cast(msg->Start()); + ChipLogError(SecureChannel, "Received error (%d) during CASE pairing process", pMsg->error); -exit: - Clear(); + CHIP_ERROR err = CHIP_NO_ERROR; + switch (pMsg->error) + { + case SigmaErrorType::kNoSharedTrustRoots: + err = CHIP_ERROR_CERT_NOT_TRUSTED; + break; + + case SigmaErrorType::kUnsupportedVersion: + err = CHIP_ERROR_UNSUPPORTED_CASE_CONFIGURATION; + break; + + case SigmaErrorType::kInvalidSignature: + case SigmaErrorType::kInvalidResumptionTag: + case SigmaErrorType::kUnexpected: + err = CHIP_ERROR_INVALID_CASE_PARAMETER; + break; + + default: + err = CHIP_ERROR_INTERNAL; + break; + }; + + return err; } CHIP_ERROR CASESession::ValidateReceivedMessage(ExchangeContext * ec, const PacketHeader & packetHeader, @@ -1157,12 +1183,7 @@ void CASESession::OnMessageReceived(ExchangeContext * ec, const PacketHeader & p System::PacketBufferHandle && msg) { CHIP_ERROR err = ValidateReceivedMessage(ec, packetHeader, payloadHeader, msg); - - if (err != CHIP_NO_ERROR) - { - Clear(); - SuccessOrExit(err); - } + SuccessOrExit(err); mConnectionState.SetPeerAddress(mMessageDispatch.GetPeerAddress()); @@ -1181,7 +1202,7 @@ void CASESession::OnMessageReceived(ExchangeContext * ec, const PacketHeader & p break; case Protocols::SecureChannel::MsgType::CASE_SigmaErr: - HandleErrorMsg(msg); + err = HandleErrorMsg(msg); break; default: @@ -1195,6 +1216,7 @@ void CASESession::OnMessageReceived(ExchangeContext * ec, const PacketHeader & p // Call delegate to indicate session establishment failure. if (err != CHIP_NO_ERROR) { + Clear(); mDelegate->OnSessionEstablishmentError(err); } } diff --git a/src/protocols/secure_channel/CASESession.h b/src/protocols/secure_channel/CASESession.h index e6818a56cf2679..41b38e6d9ab190 100644 --- a/src/protocols/secure_channel/CASESession.h +++ b/src/protocols/secure_channel/CASESession.h @@ -224,7 +224,13 @@ class DLL_EXPORT CASESession : public Messaging::ExchangeDelegate, public Pairin CHIP_ERROR ComputeIPK(const uint16_t sessionID, uint8_t * ipk, size_t ipkLen); void SendErrorMsg(SigmaErrorType errorCode); - void HandleErrorMsg(const System::PacketBufferHandle & msg); + + // This function always returns an error. The error value corresponds to the error in the received message. + // The returned error value helps top level message receiver/dispatcher to close the exchange context + // in a more seamless manner. + CHIP_ERROR HandleErrorMsg(const System::PacketBufferHandle & msg); + + void CloseExchange(); // TODO: Remove this and replace with system method to retrieve current time CHIP_ERROR SetEffectiveTime(void);