From 2f91737a5842558f941ff38401c0ace1b06a4bb3 Mon Sep 17 00:00:00 2001 From: Pankaj Garg Date: Thu, 9 Sep 2021 13:56:54 -0700 Subject: [PATCH 1/2] Update PASE state machine to match the latest specifications * Use TLV formatted messages * Use StatusReport to indicate errors and completion --- src/protocols/secure_channel/Constants.h | 7 +- src/protocols/secure_channel/PASESession.cpp | 527 +++++++++++------- src/protocols/secure_channel/PASESession.h | 20 +- .../SessionEstablishmentExchangeDispatch.cpp | 1 + 4 files changed, 331 insertions(+), 224 deletions(-) diff --git a/src/protocols/secure_channel/Constants.h b/src/protocols/secure_channel/Constants.h index 48825aa182613d..ed0f4078b2a3b7 100644 --- a/src/protocols/secure_channel/Constants.h +++ b/src/protocols/secure_channel/Constants.h @@ -72,7 +72,12 @@ enum class MsgType : uint8_t }; // Placeholder value for the ProtocolCode field when the GeneralCode is Success or Continue. -constexpr uint16_t kProtocolCodeSuccess = 0x0000; +constexpr uint16_t kProtocolCodeSuccess = 0x0000; +constexpr uint16_t kProtocolCodeNoSharedRoot = 0x0001; +constexpr uint16_t kProtocolCodeInvalidParam = 0x0002; +constexpr uint16_t kProtocolCodeCloseSession = 0x0003; +constexpr uint16_t kProtocolCodeBusy = 0x0004; +constexpr uint16_t kProtocolCodeSessionNotFound = 0x0005; // Placeholder value for the ProtocolCode field when there is no additional protocol-specific code to provide more information. constexpr uint16_t kProtocolCodeGeneralFailure = 0xFFFF; diff --git a/src/protocols/secure_channel/PASESession.cpp b/src/protocols/secure_channel/PASESession.cpp index 7eaef6e74d8eab..d7f35b99c2841a 100644 --- a/src/protocols/secure_channel/PASESession.cpp +++ b/src/protocols/secure_channel/PASESession.cpp @@ -42,13 +42,16 @@ #include #include #include +#include #include +#include #include namespace chip { using namespace Crypto; using namespace Messaging; +using namespace Protocols::SecureChannel; const char * kSpake2pContext = "CHIP PAKE V1 Commissioning"; const char * kSpake2pI2RSessionInfo = "Commissioning I2R Key"; @@ -82,7 +85,7 @@ void PASESession::Clear() memset(&mPoint[0], 0, sizeof(mPoint)); memset(&mPASEVerifier, 0, sizeof(mPASEVerifier)); memset(&mKe[0], 0, sizeof(mKe)); - mNextExpectedMsg = Protocols::SecureChannel::MsgType::PASE_Spake2pError; + mNextExpectedMsg = MsgType::StatusReport; // Note: we don't need to explicitly clear the state of mSpake2p object. // Clearing the following state takes care of it. @@ -273,7 +276,7 @@ CHIP_ERROR PASESession::WaitForPairing(uint32_t mySetUpPINCode, uint32_t pbkdf2I mIterationCount = pbkdf2IterCount; - mNextExpectedMsg = Protocols::SecureChannel::MsgType::PBKDFParamRequest; + mNextExpectedMsg = MsgType::PBKDFParamRequest; mPairingComplete = false; mPasscodeID = 0; @@ -346,133 +349,200 @@ CHIP_ERROR PASESession::DeriveSecureSession(SecureSession & session, SecureSessi CHIP_ERROR PASESession::SendPBKDFParamRequest() { - System::PacketBufferHandle req = System::PacketBufferHandle::New(kPBKDFParamRandomNumberSize); + uint8_t initiatorRandom[kPBKDFParamRandomNumberSize] = { 0 }; + ReturnErrorOnFailure(DRBG_get_bytes(initiatorRandom, kPBKDFParamRandomNumberSize)); + + size_t data_len = + EstimateTLVStructOverhead(kPBKDFParamRandomNumberSize + sizeof(uint16_t) + sizeof(uint16_t) + sizeof(uint8_t), 4); + System::PacketBufferHandle req = System::PacketBufferHandle::New(data_len); VerifyOrReturnError(!req.IsNull(), CHIP_ERROR_NO_MEMORY); - ReturnErrorOnFailure(DRBG_get_bytes(req->Start(), kPBKDFParamRandomNumberSize)); + System::PacketBufferTLVWriter tlvWriter; + tlvWriter.Init(std::move(req)); - req->SetDataLength(kPBKDFParamRandomNumberSize); + TLV::TLVType outerContainerType = TLV::kTLVType_NotSpecified; + ReturnErrorOnFailure(tlvWriter.StartContainer(TLV::AnonymousTag, TLV::kTLVType_Structure, outerContainerType)); + ReturnErrorOnFailure(tlvWriter.PutBytes(TLV::ContextTag(1), initiatorRandom, sizeof(initiatorRandom))); + ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(2), GetLocalKeyId(), true)); + ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(3), mPasscodeID, true)); + ReturnErrorOnFailure(tlvWriter.PutBoolean(TLV::ContextTag(4), mHavePBKDFParameters)); + ReturnErrorOnFailure(tlvWriter.EndContainer(outerContainerType)); + ReturnErrorOnFailure(tlvWriter.Finalize(&req)); // Update commissioning hash with the pbkdf2 param request that's being sent. ReturnErrorOnFailure(mCommissioningHash.AddData(ByteSpan{ req->Start(), req->DataLength() })); - mNextExpectedMsg = Protocols::SecureChannel::MsgType::PBKDFParamResponse; + mNextExpectedMsg = MsgType::PBKDFParamResponse; - ReturnErrorOnFailure(mExchangeCtxt->SendMessage(Protocols::SecureChannel::MsgType::PBKDFParamRequest, std::move(req), - SendFlags(SendMessageFlags::kExpectResponse))); + ReturnErrorOnFailure( + mExchangeCtxt->SendMessage(MsgType::PBKDFParamRequest, std::move(req), SendFlags(SendMessageFlags::kExpectResponse))); ChipLogDetail(SecureChannel, "Sent PBKDF param request"); return CHIP_NO_ERROR; } -CHIP_ERROR PASESession::HandlePBKDFParamRequest(const System::PacketBufferHandle & msg) +CHIP_ERROR PASESession::HandlePBKDFParamRequest(System::PacketBufferHandle && msg) { CHIP_ERROR err = CHIP_NO_ERROR; - // Request message processing - const uint8_t * req = msg->Start(); - size_t reqlen = msg->DataLength(); + System::PacketBufferTLVReader tlvReader; + TLV::TLVType containerType = TLV::kTLVType_Structure; + + uint16_t initiatorSessionId = 0; + uint8_t initiatorRandom[kPBKDFParamRandomNumberSize]; - VerifyOrExit(req != nullptr, err = CHIP_ERROR_MESSAGE_INCOMPLETE); - VerifyOrExit(reqlen == kPBKDFParamRandomNumberSize, err = CHIP_ERROR_INVALID_MESSAGE_LENGTH); + uint32_t decodeTagIdSeq = 0; + bool hasPBKDFParameters = false; ChipLogDetail(SecureChannel, "Received PBKDF param request"); - // Update commissioning hash with the received pbkdf2 param request - err = mCommissioningHash.AddData(ByteSpan{ req, reqlen }); - SuccessOrExit(err); + SuccessOrExit(err = mCommissioningHash.AddData(ByteSpan{ msg->Start(), msg->DataLength() })); + + tlvReader.Init(std::move(msg)); + SuccessOrExit(err = tlvReader.Next(containerType, TLV::AnonymousTag)); + SuccessOrExit(err = tlvReader.EnterContainer(containerType)); + + SuccessOrExit(err = tlvReader.Next()); + VerifyOrExit(TLV::TagNumFromTag(tlvReader.GetTag()) == ++decodeTagIdSeq, err = CHIP_ERROR_INVALID_TLV_TAG); + SuccessOrExit(err = tlvReader.GetBytes(initiatorRandom, sizeof(initiatorRandom))); - err = SendPBKDFParamResponse(); + SuccessOrExit(err = tlvReader.Next()); + VerifyOrExit(TLV::TagNumFromTag(tlvReader.GetTag()) == ++decodeTagIdSeq, err = CHIP_ERROR_INVALID_TLV_TAG); + SuccessOrExit(err = tlvReader.Get(initiatorSessionId)); + + ChipLogDetail(SecureChannel, "Peer assigned session key ID %d", initiatorSessionId); + SetPeerKeyId(initiatorSessionId); + + SuccessOrExit(err = tlvReader.Next()); + VerifyOrExit(TLV::TagNumFromTag(tlvReader.GetTag()) == ++decodeTagIdSeq, err = CHIP_ERROR_INVALID_TLV_TAG); + SuccessOrExit(err = tlvReader.Get(mPasscodeID)); + + SuccessOrExit(err = tlvReader.Next()); + VerifyOrExit(TLV::TagNumFromTag(tlvReader.GetTag()) == ++decodeTagIdSeq, err = CHIP_ERROR_INVALID_TLV_TAG); + SuccessOrExit(err = tlvReader.Get(hasPBKDFParameters)); + + err = SendPBKDFParamResponse(ByteSpan(initiatorRandom), hasPBKDFParameters); SuccessOrExit(err); exit: if (err != CHIP_NO_ERROR) { - SendErrorMsg(Spake2pErrorType::kUnexpected); + SendStatusReport(kProtocolCodeInvalidParam); } return err; } -CHIP_ERROR PASESession::SendPBKDFParamResponse() +CHIP_ERROR PASESession::SendPBKDFParamResponse(ByteSpan initiatorRandom, bool initiatorHasPBKDFParams) { - System::PacketBufferHandle resp; - static_assert(CHAR_BIT == 8, "Assuming sizeof() returns octets here and for sizeof(mPoint)"); - size_t resplen = kPBKDFParamRandomNumberSize + sizeof(uint64_t) + sizeof(uint32_t) + mSaltLength; + uint8_t responderRandom[kPBKDFParamRandomNumberSize] = { 0 }; + ReturnErrorOnFailure(DRBG_get_bytes(responderRandom, kPBKDFParamRandomNumberSize)); - size_t sizeof_point = sizeof(mPoint); - - uint8_t * msg = nullptr; - - resp = System::PacketBufferHandle::New(resplen); + size_t data_len = EstimateTLVStructOverhead( + kPBKDFParamRandomNumberSize + kPBKDFParamRandomNumberSize + sizeof(uint16_t) + sizeof(uint16_t) + sizeof(uint8_t), 5); + System::PacketBufferHandle resp = System::PacketBufferHandle::New(data_len); VerifyOrReturnError(!resp.IsNull(), CHIP_ERROR_NO_MEMORY); - msg = resp->Start(); + System::PacketBufferTLVWriter tlvWriter; + tlvWriter.Init(std::move(resp)); - // Fill in the random value - ReturnErrorOnFailure(DRBG_get_bytes(msg, kPBKDFParamRandomNumberSize)); + TLV::TLVType outerContainerType = TLV::kTLVType_NotSpecified; + ReturnErrorOnFailure(tlvWriter.StartContainer(TLV::AnonymousTag, TLV::kTLVType_Structure, outerContainerType)); + ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(1), initiatorRandom)); + ReturnErrorOnFailure(tlvWriter.PutBytes(TLV::ContextTag(2), responderRandom, sizeof(responderRandom))); + ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(3), GetLocalKeyId(), true)); - // Let's construct the rest of the message using BufferWriter + if (!initiatorHasPBKDFParams) { - Encoding::LittleEndian::BufferWriter bbuf(&msg[kPBKDFParamRandomNumberSize], resplen - kPBKDFParamRandomNumberSize); - bbuf.Put64(mIterationCount); - bbuf.Put32(mSaltLength); - bbuf.Put(mSalt, mSaltLength); - VerifyOrReturnError(bbuf.Fit(), CHIP_ERROR_NO_MEMORY); + TLV::TLVType outerContainer; + ReturnErrorOnFailure(tlvWriter.StartContainer(TLV::ContextTag(4), TLV::kTLVType_Structure, outerContainer)); + ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(1), mIterationCount, true)); + ReturnErrorOnFailure(tlvWriter.PutBytes(TLV::ContextTag(2), mSalt, mSaltLength)); + ReturnErrorOnFailure(tlvWriter.EndContainer(outerContainer)); } - resp->SetDataLength(static_cast(resplen)); + ReturnErrorOnFailure(tlvWriter.EndContainer(outerContainerType)); + ReturnErrorOnFailure(tlvWriter.Finalize(&resp)); // Update commissioning hash with the pbkdf2 param response that's being sent. ReturnErrorOnFailure(mCommissioningHash.AddData(ByteSpan{ resp->Start(), resp->DataLength() })); ReturnErrorOnFailure(SetupSpake2p(mIterationCount, ByteSpan(mSalt, mSaltLength))); + + size_t sizeof_point = sizeof(mPoint); ReturnErrorOnFailure(mSpake2p.ComputeL(mPoint, &sizeof_point, mPASEVerifier.mL, kSpake2p_WS_Length)); - mNextExpectedMsg = Protocols::SecureChannel::MsgType::PASE_Spake2p1; + mNextExpectedMsg = MsgType::PASE_Spake2p1; - ReturnErrorOnFailure(mExchangeCtxt->SendMessage(Protocols::SecureChannel::MsgType::PBKDFParamResponse, std::move(resp), - SendFlags(SendMessageFlags::kExpectResponse))); + ReturnErrorOnFailure( + mExchangeCtxt->SendMessage(MsgType::PBKDFParamResponse, std::move(resp), SendFlags(SendMessageFlags::kExpectResponse))); ChipLogDetail(SecureChannel, "Sent PBKDF param response"); return CHIP_NO_ERROR; } -CHIP_ERROR PASESession::HandlePBKDFParamResponse(const System::PacketBufferHandle & msg) +CHIP_ERROR PASESession::HandlePBKDFParamResponse(System::PacketBufferHandle && msg) { + CHIP_ERROR err = CHIP_NO_ERROR; - // Response message processing - const uint8_t * resp = msg->Start(); - size_t resplen = msg->DataLength(); + System::PacketBufferTLVReader tlvReader; + TLV::TLVType containerType = TLV::kTLVType_Structure; + + uint16_t responderSessionId = 0; + uint8_t random[kPBKDFParamRandomNumberSize]; - // This the fixed part of the message. The variable part of the message contains the salt. - // The length of the variable part is determined by the salt length in the fixed header. - static_assert(CHAR_BIT == 8, "Assuming that sizeof returns octets"); - size_t fixed_resplen = kPBKDFParamRandomNumberSize + sizeof(uint64_t) + sizeof(uint32_t); + uint32_t decodeTagIdSeq = 0; + uint32_t iterCount = 0; + uint32_t saltLength = 0; + const uint8_t * salt; ChipLogDetail(SecureChannel, "Received PBKDF param response"); - VerifyOrExit(resp != nullptr, err = CHIP_ERROR_MESSAGE_INCOMPLETE); - VerifyOrExit(resplen >= fixed_resplen, err = CHIP_ERROR_INVALID_MESSAGE_LENGTH); + SuccessOrExit(err = mCommissioningHash.AddData(ByteSpan{ msg->Start(), msg->DataLength() })); - { - // Let's skip the random number portion of the message - const uint8_t * msgptr = &resp[kPBKDFParamRandomNumberSize]; - uint64_t iterCount = chip::Encoding::LittleEndian::Read64(msgptr); - uint32_t saltlen = chip::Encoding::LittleEndian::Read32(msgptr); + tlvReader.Init(std::move(msg)); + SuccessOrExit(err = tlvReader.Next(containerType, TLV::AnonymousTag)); + SuccessOrExit(err = tlvReader.EnterContainer(containerType)); + + SuccessOrExit(err = tlvReader.Next()); + VerifyOrExit(TLV::TagNumFromTag(tlvReader.GetTag()) == ++decodeTagIdSeq, err = CHIP_ERROR_INVALID_TLV_TAG); + // Initiator's random value + SuccessOrExit(err = tlvReader.GetBytes(random, sizeof(random))); + + SuccessOrExit(err = tlvReader.Next()); + VerifyOrExit(TLV::TagNumFromTag(tlvReader.GetTag()) == ++decodeTagIdSeq, err = CHIP_ERROR_INVALID_TLV_TAG); + // Responder's random value + SuccessOrExit(err = tlvReader.GetBytes(random, sizeof(random))); - VerifyOrExit(resplen == fixed_resplen + saltlen, err = CHIP_ERROR_INVALID_MESSAGE_LENGTH); + SuccessOrExit(err = tlvReader.Next()); + VerifyOrExit(TLV::TagNumFromTag(tlvReader.GetTag()) == ++decodeTagIdSeq, err = CHIP_ERROR_INVALID_TLV_TAG); + SuccessOrExit(err = tlvReader.Get(responderSessionId)); - // Specifications allow message to carry a uint64_t sized iteration count. Current APIs are - // limiting it to uint32_t. Let's make sure it'll fit the size limit. - VerifyOrExit(CanCastTo(iterCount), err = CHIP_ERROR_INVALID_MESSAGE_LENGTH); + ChipLogDetail(SecureChannel, "Peer assigned session key ID %d", responderSessionId); + SetPeerKeyId(responderSessionId); - // Update commissioning hash with the received pbkdf2 param response - err = mCommissioningHash.AddData(ByteSpan{ resp, resplen }); + if (mHavePBKDFParameters) + { + err = SetupSpake2p(iterCount, ByteSpan(mSalt, mSaltLength)); SuccessOrExit(err); + } + else + { + SuccessOrExit(err = tlvReader.Next()); + SuccessOrExit(err = tlvReader.EnterContainer(containerType)); + decodeTagIdSeq = 0; + + SuccessOrExit(err = tlvReader.Next()); + VerifyOrExit(TLV::TagNumFromTag(tlvReader.GetTag()) == ++decodeTagIdSeq, err = CHIP_ERROR_INVALID_TLV_TAG); + SuccessOrExit(err = tlvReader.Get(iterCount)); + + SuccessOrExit(err = tlvReader.Next()); + VerifyOrExit(TLV::TagNumFromTag(tlvReader.GetTag()) == ++decodeTagIdSeq, err = CHIP_ERROR_INVALID_TLV_TAG); + saltLength = tlvReader.GetLength(); + SuccessOrExit(err = tlvReader.GetDataPtr(salt)); - err = SetupSpake2p(static_cast(iterCount), ByteSpan(msgptr, saltlen)); + err = SetupSpake2p(iterCount, ByteSpan(salt, saltLength)); SuccessOrExit(err); } @@ -482,38 +552,44 @@ CHIP_ERROR PASESession::HandlePBKDFParamResponse(const System::PacketBufferHandl exit: if (err != CHIP_NO_ERROR) { - SendErrorMsg(Spake2pErrorType::kUnexpected); + SendStatusReport(kProtocolCodeInvalidParam); } return err; } CHIP_ERROR PASESession::SendMsg1() { + size_t data_len = EstimateTLVStructOverhead(kMAX_Point_Length, 1); + System::PacketBufferHandle msg = System::PacketBufferHandle::New(data_len); + VerifyOrReturnError(!msg.IsNull(), CHIP_ERROR_NO_MEMORY); + + System::PacketBufferTLVWriter tlvWriter; + tlvWriter.Init(std::move(msg)); + + TLV::TLVType outerContainerType = TLV::kTLVType_NotSpecified; + ReturnErrorOnFailure(tlvWriter.StartContainer(TLV::AnonymousTag, TLV::kTLVType_Structure, outerContainerType)); + uint8_t X[kMAX_Point_Length]; size_t X_len = sizeof(X); ReturnErrorOnFailure( mSpake2p.BeginProver(nullptr, 0, nullptr, 0, mPASEVerifier.mW0, kSpake2p_WS_Length, mPASEVerifier.mL, kSpake2p_WS_Length)); - ReturnErrorOnFailure(mSpake2p.ComputeRoundOne(NULL, 0, X, &X_len)); + ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(1), ByteSpan(X, X_len))); + ReturnErrorOnFailure(tlvWriter.EndContainer(outerContainerType)); + ReturnErrorOnFailure(tlvWriter.Finalize(&msg)); - Encoding::LittleEndian::PacketBufferWriter bbuf(System::PacketBufferHandle::New(sizeof(uint16_t) + X_len)); - VerifyOrReturnError(!bbuf.IsNull(), CHIP_ERROR_NO_MEMORY); - bbuf.Put16(GetLocalKeyId()); - bbuf.Put(&X[0], X_len); - VerifyOrReturnError(bbuf.Fit(), CHIP_ERROR_NO_MEMORY); - - mNextExpectedMsg = Protocols::SecureChannel::MsgType::PASE_Spake2p2; + mNextExpectedMsg = MsgType::PASE_Spake2p2; // Call delegate to send the Msg1 to peer - ReturnErrorOnFailure(mExchangeCtxt->SendMessage(Protocols::SecureChannel::MsgType::PASE_Spake2p1, bbuf.Finalize(), - SendFlags(SendMessageFlags::kExpectResponse))); + ReturnErrorOnFailure( + mExchangeCtxt->SendMessage(MsgType::PASE_Spake2p1, std::move(msg), SendFlags(SendMessageFlags::kExpectResponse))); ChipLogDetail(SecureChannel, "Sent spake2p msg1"); return CHIP_NO_ERROR; } -CHIP_ERROR PASESession::HandleMsg1_and_SendMsg2(const System::PacketBufferHandle & msg) +CHIP_ERROR PASESession::HandleMsg1_and_SendMsg2(System::PacketBufferHandle && msg) { CHIP_ERROR err = CHIP_NO_ERROR; @@ -523,52 +599,49 @@ CHIP_ERROR PASESession::HandleMsg1_and_SendMsg2(const System::PacketBufferHandle uint8_t verifier[kMAX_Hash_Length]; size_t verifier_len = kMAX_Hash_Length; - uint16_t data_len; // To be initialized once we compute it. - - const uint8_t * buf = msg->Start(); - size_t buf_len = msg->DataLength(); - - uint16_t encryptionKeyId = 0; - ChipLogDetail(SecureChannel, "Received spake2p msg1"); - VerifyOrExit(buf != nullptr, err = CHIP_ERROR_MESSAGE_INCOMPLETE); - VerifyOrExit(buf_len == sizeof(encryptionKeyId) + kMAX_Point_Length, err = CHIP_ERROR_INVALID_MESSAGE_LENGTH); + System::PacketBufferTLVReader tlvReader; + TLV::TLVType containerType = TLV::kTLVType_Structure; - err = mSpake2p.BeginVerifier(nullptr, 0, nullptr, 0, mPASEVerifier.mW0, kSpake2p_WS_Length, mPoint, sizeof(mPoint)); - SuccessOrExit(err); + const uint8_t * X; + size_t X_len = 0; + + tlvReader.Init(std::move(msg)); + SuccessOrExit(err = tlvReader.Next(containerType, TLV::AnonymousTag)); + SuccessOrExit(err = tlvReader.EnterContainer(containerType)); - encryptionKeyId = chip::Encoding::LittleEndian::Read16(buf); - msg->ConsumeHead(sizeof(encryptionKeyId)); + SuccessOrExit(err = tlvReader.Next()); + VerifyOrExit(TLV::TagNumFromTag(tlvReader.GetTag()) == 1, err = CHIP_ERROR_INVALID_TLV_TAG); + X_len = tlvReader.GetLength(); + SuccessOrExit(err = tlvReader.GetDataPtr(X)); + SuccessOrExit( + err = mSpake2p.BeginVerifier(nullptr, 0, nullptr, 0, mPASEVerifier.mW0, kSpake2p_WS_Length, mPoint, sizeof(mPoint))); // Pass Pa to check abort condition. - err = mSpake2p.ComputeRoundOne(msg->Start(), msg->DataLength(), Y, &Y_len); - SuccessOrExit(err); + SuccessOrExit(err = mSpake2p.ComputeRoundOne(X, X_len, Y, &Y_len)); + SuccessOrExit(err = mSpake2p.ComputeRoundTwo(X, X_len, verifier, &verifier_len)); - ChipLogDetail(SecureChannel, "Peer assigned session key ID %d", encryptionKeyId); - SetPeerKeyId(encryptionKeyId); + { + size_t data_len = EstimateTLVStructOverhead(Y_len + verifier_len, 2); - err = mSpake2p.ComputeRoundTwo(msg->Start(), msg->DataLength(), verifier, &verifier_len); - SuccessOrExit(err); + System::PacketBufferHandle msg2 = System::PacketBufferHandle::New(data_len); + VerifyOrExit(!msg2.IsNull(), err = CHIP_ERROR_NO_MEMORY); - // Make sure our addition doesn't overflow. - VerifyOrExit(UINTMAX_MAX - verifier_len >= Y_len, err = CHIP_ERROR_INVALID_MESSAGE_LENGTH); - VerifyOrExit(CanCastTo(Y_len + verifier_len), err = CHIP_ERROR_INVALID_MESSAGE_LENGTH); - data_len = static_cast(sizeof(encryptionKeyId) + Y_len + verifier_len); + System::PacketBufferTLVWriter tlvWriter; + tlvWriter.Init(std::move(msg2)); - { - Encoding::LittleEndian::PacketBufferWriter bbuf(System::PacketBufferHandle::New(data_len)); - VerifyOrExit(!bbuf.IsNull(), err = CHIP_ERROR_NO_MEMORY); - bbuf.Put16(GetLocalKeyId()); - bbuf.Put(&Y[0], Y_len); - bbuf.Put(verifier, verifier_len); - VerifyOrExit(bbuf.Fit(), err = CHIP_ERROR_NO_MEMORY); + TLV::TLVType outerContainerType = TLV::kTLVType_NotSpecified; + SuccessOrExit(err = tlvWriter.StartContainer(TLV::AnonymousTag, TLV::kTLVType_Structure, outerContainerType)); + SuccessOrExit(err = tlvWriter.Put(TLV::ContextTag(1), ByteSpan(Y, Y_len))); + SuccessOrExit(err = tlvWriter.Put(TLV::ContextTag(2), ByteSpan(verifier, verifier_len))); + SuccessOrExit(err = tlvWriter.EndContainer(outerContainerType)); + SuccessOrExit(err = tlvWriter.Finalize(&msg2)); - mNextExpectedMsg = Protocols::SecureChannel::MsgType::PASE_Spake2p3; + mNextExpectedMsg = MsgType::PASE_Spake2p3; // Call delegate to send the Msg2 to peer - err = mExchangeCtxt->SendMessage(Protocols::SecureChannel::MsgType::PASE_Spake2p2, bbuf.Finalize(), - SendFlags(SendMessageFlags::kExpectResponse)); + err = mExchangeCtxt->SendMessage(MsgType::PASE_Spake2p2, std::move(msg2), SendFlags(SendMessageFlags::kExpectResponse)); SuccessOrExit(err); } @@ -578,115 +651,131 @@ CHIP_ERROR PASESession::HandleMsg1_and_SendMsg2(const System::PacketBufferHandle if (err != CHIP_NO_ERROR) { - SendErrorMsg(Spake2pErrorType::kUnexpected); + SendStatusReport(kProtocolCodeInvalidParam); } return err; } -CHIP_ERROR PASESession::HandleMsg2_and_SendMsg3(const System::PacketBufferHandle & msg) +CHIP_ERROR PASESession::HandleMsg2_and_SendMsg3(System::PacketBufferHandle && msg) { CHIP_ERROR err = CHIP_NO_ERROR; uint8_t verifier[kMAX_Hash_Length]; - size_t verifier_len_raw = kMAX_Hash_Length; - uint16_t verifier_len; // To be inited one we check length is small enough - - uint8_t * buf = msg->Start(); - size_t buf_len = msg->DataLength(); + size_t verifier_len = kMAX_Hash_Length; System::PacketBufferHandle resp; - Spake2pErrorType spake2pErr = Spake2pErrorType::kUnexpected; - - uint16_t encryptionKeyId = 0; + uint16_t spake2pErr = kProtocolCodeInvalidParam; ChipLogDetail(SecureChannel, "Received spake2p msg2"); - VerifyOrExit(buf != nullptr, err = CHIP_ERROR_MESSAGE_INCOMPLETE); - VerifyOrExit(buf_len == sizeof(encryptionKeyId) + kMAX_Point_Length + kMAX_Hash_Length, - err = CHIP_ERROR_INVALID_MESSAGE_LENGTH); + System::PacketBufferTLVReader tlvReader; + TLV::TLVType containerType = TLV::kTLVType_Structure; - encryptionKeyId = chip::Encoding::LittleEndian::Read16(buf); - msg->ConsumeHead(sizeof(encryptionKeyId)); - buf = msg->Start(); - buf_len = msg->DataLength(); + const uint8_t * Y; + size_t Y_len = 0; - ChipLogDetail(SecureChannel, "Peer assigned session key ID %d", encryptionKeyId); - SetPeerKeyId(encryptionKeyId); + const uint8_t * peer_verifier; + size_t peer_verifier_len = 0; - err = mSpake2p.ComputeRoundTwo(buf, kMAX_Point_Length, verifier, &verifier_len_raw); - SuccessOrExit(err); - VerifyOrExit(CanCastTo(verifier_len_raw), err = CHIP_ERROR_INVALID_MESSAGE_LENGTH); - verifier_len = static_cast(verifier_len_raw); + uint32_t decodeTagIdSeq = 0; - { - const uint8_t * hash = &buf[kMAX_Point_Length]; - err = mSpake2p.KeyConfirm(hash, kMAX_Hash_Length); - if (err != CHIP_NO_ERROR) - { - spake2pErr = Spake2pErrorType::kInvalidKeyConfirmation; - SuccessOrExit(err); - } + tlvReader.Init(std::move(msg)); + SuccessOrExit(err = tlvReader.Next(containerType, TLV::AnonymousTag)); + SuccessOrExit(err = tlvReader.EnterContainer(containerType)); - err = mSpake2p.GetKeys(mKe, &mKeLen); + SuccessOrExit(err = tlvReader.Next()); + VerifyOrExit(TLV::TagNumFromTag(tlvReader.GetTag()) == ++decodeTagIdSeq, err = CHIP_ERROR_INVALID_TLV_TAG); + Y_len = tlvReader.GetLength(); + SuccessOrExit(err = tlvReader.GetDataPtr(Y)); + + SuccessOrExit(err = tlvReader.Next()); + VerifyOrExit(TLV::TagNumFromTag(tlvReader.GetTag()) == ++decodeTagIdSeq, err = CHIP_ERROR_INVALID_TLV_TAG); + peer_verifier_len = tlvReader.GetLength(); + SuccessOrExit(err = tlvReader.GetDataPtr(peer_verifier)); + + SuccessOrExit(err = mSpake2p.ComputeRoundTwo(Y, Y_len, verifier, &verifier_len)); + + err = mSpake2p.KeyConfirm(peer_verifier, peer_verifier_len); + if (err != CHIP_NO_ERROR) + { + spake2pErr = kProtocolCodeNoSharedRoot; SuccessOrExit(err); } + SuccessOrExit(err = mSpake2p.GetKeys(mKe, &mKeLen)); + { - Encoding::PacketBufferWriter bbuf(System::PacketBufferHandle::New(verifier_len)); - VerifyOrExit(!bbuf.IsNull(), err = CHIP_ERROR_NO_MEMORY); + size_t data_len = EstimateTLVStructOverhead(verifier_len, 1); + + System::PacketBufferHandle msg3 = System::PacketBufferHandle::New(data_len); + VerifyOrExit(!msg3.IsNull(), err = CHIP_ERROR_NO_MEMORY); + + System::PacketBufferTLVWriter tlvWriter; + tlvWriter.Init(std::move(msg3)); - bbuf.Put(verifier, verifier_len); - VerifyOrExit(bbuf.Fit(), err = CHIP_ERROR_NO_MEMORY); + TLV::TLVType outerContainerType = TLV::kTLVType_NotSpecified; + SuccessOrExit(err = tlvWriter.StartContainer(TLV::AnonymousTag, TLV::kTLVType_Structure, outerContainerType)); + SuccessOrExit(err = tlvWriter.Put(TLV::ContextTag(1), ByteSpan(verifier, verifier_len))); + SuccessOrExit(err = tlvWriter.EndContainer(outerContainerType)); + SuccessOrExit(err = tlvWriter.Finalize(&msg3)); + + mNextExpectedMsg = MsgType::StatusReport; // Call delegate to send the Msg3 to peer - err = mExchangeCtxt->SendMessage(Protocols::SecureChannel::MsgType::PASE_Spake2p3, bbuf.Finalize()); + err = mExchangeCtxt->SendMessage(MsgType::PASE_Spake2p3, std::move(msg3), SendFlags(SendMessageFlags::kExpectResponse)); SuccessOrExit(err); } ChipLogDetail(SecureChannel, "Sent spake2p msg3"); - mPairingComplete = true; - - // Forget our exchange, as no additional messages are expected from the peer - mExchangeCtxt = nullptr; - - // Call delegate to indicate pairing completion - mDelegate->OnSessionEstablished(); - exit: if (err != CHIP_NO_ERROR) { - SendErrorMsg(spake2pErr); + SendStatusReport(spake2pErr); } return err; } -CHIP_ERROR PASESession::HandleMsg3(const System::PacketBufferHandle & msg) +CHIP_ERROR PASESession::HandleMsg3(System::PacketBufferHandle && msg) { - CHIP_ERROR err = CHIP_NO_ERROR; - const uint8_t * hash = msg->Start(); - Spake2pErrorType spake2pErr = Spake2pErrorType::kUnexpected; + CHIP_ERROR err = CHIP_NO_ERROR; + uint16_t spake2pErr = kProtocolCodeInvalidParam; ChipLogDetail(SecureChannel, "Received spake2p msg3"); - // We will set NextExpectedMsg to PASE_Spake2pError in all cases + // We will set NextExpectedMsg to StatusReport in all cases // However, when we are using IP rendezvous, we might set it to PASE_Spake2p1. - mNextExpectedMsg = Protocols::SecureChannel::MsgType::PASE_Spake2pError; + mNextExpectedMsg = MsgType::StatusReport; + + System::PacketBufferTLVReader tlvReader; + TLV::TLVType containerType = TLV::kTLVType_Structure; + + const uint8_t * peer_verifier; + size_t peer_verifier_len = 0; - VerifyOrExit(hash != nullptr, err = CHIP_ERROR_MESSAGE_INCOMPLETE); - VerifyOrExit(msg->DataLength() == kMAX_Hash_Length, err = CHIP_ERROR_INVALID_MESSAGE_LENGTH); + tlvReader.Init(std::move(msg)); + SuccessOrExit(err = tlvReader.Next(containerType, TLV::AnonymousTag)); + SuccessOrExit(err = tlvReader.EnterContainer(containerType)); - err = mSpake2p.KeyConfirm(hash, kMAX_Hash_Length); + SuccessOrExit(err = tlvReader.Next()); + VerifyOrExit(TLV::TagNumFromTag(tlvReader.GetTag()) == 1, err = CHIP_ERROR_INVALID_TLV_TAG); + peer_verifier_len = tlvReader.GetLength(); + SuccessOrExit(err = tlvReader.GetDataPtr(peer_verifier)); + + VerifyOrExit(peer_verifier_len == kMAX_Hash_Length, err = CHIP_ERROR_INVALID_MESSAGE_LENGTH); + + err = mSpake2p.KeyConfirm(peer_verifier, peer_verifier_len); if (err != CHIP_NO_ERROR) { - spake2pErr = Spake2pErrorType::kInvalidKeyConfirmation; + spake2pErr = kProtocolCodeNoSharedRoot; SuccessOrExit(err); } - err = mSpake2p.GetKeys(mKe, &mKeLen); - SuccessOrExit(err); + SuccessOrExit(err = mSpake2p.GetKeys(mKe, &mKeLen)); + + SendStatusReport(kProtocolCodeSuccess); mPairingComplete = true; @@ -700,56 +789,65 @@ CHIP_ERROR PASESession::HandleMsg3(const System::PacketBufferHandle & msg) if (err != CHIP_NO_ERROR) { - SendErrorMsg(spake2pErr); + SendStatusReport(spake2pErr); } return err; } -void PASESession::SendErrorMsg(Spake2pErrorType errorCode) +void PASESession::SendStatusReport(uint16_t protocolCode) { - System::PacketBufferHandle msg; - uint16_t msglen = sizeof(Spake2pErrorMsg); - Spake2pErrorMsg * pMsg = nullptr; + GeneralStatusCode generalCode = + (protocolCode == kProtocolCodeSuccess) ? GeneralStatusCode::kSuccess : GeneralStatusCode::kFailure; + uint32_t protocolId = Id.ToFullyQualifiedSpecForm(); + + ChipLogDetail(SecureChannel, "Sending status report"); - msg = System::PacketBufferHandle::New(msglen); - VerifyOrReturn(!msg.IsNull(), ChipLogError(SecureChannel, "Failed to allocate error message")); + StatusReport statusReport(generalCode, protocolId, protocolCode); - pMsg = reinterpret_cast(msg->Start()); - pMsg->error = errorCode; + Encoding::LittleEndian::PacketBufferWriter bbuf(System::PacketBufferHandle::New(statusReport.Size())); + statusReport.WriteToBuffer(bbuf); - msg->SetDataLength(msglen); + System::PacketBufferHandle msg = bbuf.Finalize(); + VerifyOrReturn(!msg.IsNull(), ChipLogError(SecureChannel, "Failed to allocate status report message")); - VerifyOrReturn(mExchangeCtxt->SendMessage(Protocols::SecureChannel::MsgType::PASE_Spake2pError, std::move(msg)) == - CHIP_NO_ERROR, - ChipLogError(SecureChannel, "Failed to send error message")); + VerifyOrReturn(mExchangeCtxt->SendMessage(MsgType::StatusReport, std::move(msg)) == CHIP_NO_ERROR, + ChipLogError(SecureChannel, "Failed to send status report message")); } -CHIP_ERROR PASESession::HandleErrorMsg(const System::PacketBufferHandle & msg) +CHIP_ERROR PASESession::HandleStatusReport(System::PacketBufferHandle && msg) { - ReturnErrorCodeIf(msg->Start() == nullptr || msg->DataLength() != sizeof(Spake2pErrorMsg), CHIP_ERROR_MESSAGE_INCOMPLETE); - - static_assert( - sizeof(Spake2pErrorMsg) == sizeof(uint8_t), - "Assuming size of Spake2pErrorMsg message is 1 octet, so that endian-ness conversion and memory alignment is not needed"); + StatusReport report; + CHIP_ERROR err = report.Parse(std::move(msg)); + ReturnErrorOnFailure(err); - Spake2pErrorMsg * pMsg = reinterpret_cast(msg->Start()); - - CHIP_ERROR err = CHIP_NO_ERROR; - switch (pMsg->error) + if (report.GetGeneralCode() == GeneralStatusCode::kSuccess && report.GetProtocolCode() == kProtocolCodeSuccess) { - case Spake2pErrorType::kInvalidKeyConfirmation: - err = CHIP_ERROR_KEY_CONFIRMATION_FAILED; - break; + mPairingComplete = true; - case Spake2pErrorType::kUnexpected: - err = CHIP_ERROR_INVALID_PASE_PARAMETER; - break; + // Forget our exchange, as no additional messages are expected from the peer + mExchangeCtxt = nullptr; - default: - err = CHIP_ERROR_INTERNAL; - break; - }; - ChipLogError(SecureChannel, "Received error during pairing process. %s", ErrorStr(err)); + // Call delegate to indicate pairing completion + mDelegate->OnSessionEstablished(); + } + else + { + switch (report.GetProtocolCode()) + { + case kProtocolCodeNoSharedRoot: + err = CHIP_ERROR_KEY_CONFIRMATION_FAILED; + break; + + case kProtocolCodeInvalidParam: + err = CHIP_ERROR_INVALID_PASE_PARAMETER; + break; + + default: + err = CHIP_ERROR_INTERNAL; + break; + }; + ChipLogError(SecureChannel, "Received error during pairing process. %s", ErrorStr(err)); + } return err; } @@ -776,8 +874,7 @@ CHIP_ERROR PASESession::ValidateReceivedMessage(ExchangeContext * exchange, cons } VerifyOrReturnError(!msg.IsNull(), CHIP_ERROR_INVALID_ARGUMENT); - VerifyOrReturnError(payloadHeader.HasMessageType(mNextExpectedMsg) || - payloadHeader.HasMessageType(Protocols::SecureChannel::MsgType::PASE_Spake2pError), + VerifyOrReturnError(payloadHeader.HasMessageType(mNextExpectedMsg) || payloadHeader.HasMessageType(MsgType::StatusReport), CHIP_ERROR_INVALID_MESSAGE_TYPE); return CHIP_NO_ERROR; @@ -791,30 +888,30 @@ CHIP_ERROR PASESession::OnMessageReceived(ExchangeContext * exchange, const Pack SetPeerAddress(mMessageDispatch.GetPeerAddress()); - switch (static_cast(payloadHeader.GetMessageType())) + switch (static_cast(payloadHeader.GetMessageType())) { - case Protocols::SecureChannel::MsgType::PBKDFParamRequest: - err = HandlePBKDFParamRequest(msg); + case MsgType::PBKDFParamRequest: + err = HandlePBKDFParamRequest(std::move(msg)); break; - case Protocols::SecureChannel::MsgType::PBKDFParamResponse: - err = HandlePBKDFParamResponse(msg); + case MsgType::PBKDFParamResponse: + err = HandlePBKDFParamResponse(std::move(msg)); break; - case Protocols::SecureChannel::MsgType::PASE_Spake2p1: - err = HandleMsg1_and_SendMsg2(msg); + case MsgType::PASE_Spake2p1: + err = HandleMsg1_and_SendMsg2(std::move(msg)); break; - case Protocols::SecureChannel::MsgType::PASE_Spake2p2: - err = HandleMsg2_and_SendMsg3(msg); + case MsgType::PASE_Spake2p2: + err = HandleMsg2_and_SendMsg3(std::move(msg)); break; - case Protocols::SecureChannel::MsgType::PASE_Spake2p3: - err = HandleMsg3(msg); + case MsgType::PASE_Spake2p3: + err = HandleMsg3(std::move(msg)); break; - case Protocols::SecureChannel::MsgType::PASE_Spake2pError: - err = HandleErrorMsg(msg); + case MsgType::StatusReport: + err = HandleStatusReport(std::move(msg)); break; default: diff --git a/src/protocols/secure_channel/PASESession.h b/src/protocols/secure_channel/PASESession.h index 375738397ec0a0..c02ff0bf9b21e3 100644 --- a/src/protocols/secure_channel/PASESession.h +++ b/src/protocols/secure_channel/PASESession.h @@ -256,19 +256,21 @@ class DLL_EXPORT PASESession : public Messaging::ExchangeDelegate, public Pairin CHIP_ERROR SetupSpake2p(uint32_t pbkdf2IterCount, const ByteSpan & salt); CHIP_ERROR SendPBKDFParamRequest(); - CHIP_ERROR HandlePBKDFParamRequest(const System::PacketBufferHandle & msg); + CHIP_ERROR HandlePBKDFParamRequest(System::PacketBufferHandle && msg); - CHIP_ERROR SendPBKDFParamResponse(); - CHIP_ERROR HandlePBKDFParamResponse(const System::PacketBufferHandle & msg); + CHIP_ERROR SendPBKDFParamResponse(ByteSpan initiatorRandom, bool initiatorHasPBKDFParams); + CHIP_ERROR HandlePBKDFParamResponse(System::PacketBufferHandle && msg); CHIP_ERROR SendMsg1(); - CHIP_ERROR HandleMsg1_and_SendMsg2(const System::PacketBufferHandle & msg); - CHIP_ERROR HandleMsg2_and_SendMsg3(const System::PacketBufferHandle & msg); - CHIP_ERROR HandleMsg3(const System::PacketBufferHandle & msg); + CHIP_ERROR HandleMsg1_and_SendMsg2(System::PacketBufferHandle && msg); + CHIP_ERROR HandleMsg2_and_SendMsg3(System::PacketBufferHandle && msg); + CHIP_ERROR HandleMsg3(System::PacketBufferHandle && msg); - void SendErrorMsg(Spake2pErrorType errorCode); - CHIP_ERROR HandleErrorMsg(const System::PacketBufferHandle & msg); + void SendStatusReport(uint16_t protocolCode); + CHIP_ERROR HandleStatusReport(System::PacketBufferHandle && msg); + + constexpr size_t EstimateTLVStructOverhead(size_t dataLen, size_t nFields) { return dataLen + (sizeof(uint64_t) * nFields); } void CloseExchange(); @@ -292,6 +294,8 @@ class DLL_EXPORT PASESession : public Messaging::ExchangeDelegate, public Pairin bool mComputeVerifier = true; + bool mHavePBKDFParameters = false; + Hash_SHA256_stream mCommissioningHash; uint32_t mIterationCount = 0; uint16_t mSaltLength = 0; diff --git a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp b/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp index 0dfe8a74ba2ceb..4e14810196fa89 100644 --- a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp +++ b/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp @@ -76,6 +76,7 @@ bool SessionEstablishmentExchangeDispatch::MessagePermitted(uint16_t protocol, u case static_cast(Protocols::SecureChannel::MsgType::CASE_SigmaR2): case static_cast(Protocols::SecureChannel::MsgType::CASE_SigmaR3): case static_cast(Protocols::SecureChannel::MsgType::CASE_SigmaErr): + case static_cast(Protocols::SecureChannel::MsgType::StatusReport): return true; default: From d248472c9607a9e4308a5b7516471a69463124d2 Mon Sep 17 00:00:00 2001 From: Pankaj Garg Date: Fri, 10 Sep 2021 08:18:29 -0700 Subject: [PATCH 2/2] Address review comments --- src/messaging/ApplicationExchangeDispatch.cpp | 8 +- src/protocols/secure_channel/Constants.h | 8 +- src/protocols/secure_channel/PASESession.cpp | 100 ++++++++++-------- src/protocols/secure_channel/PASESession.h | 4 +- .../SessionEstablishmentExchangeDispatch.cpp | 8 +- 5 files changed, 68 insertions(+), 60 deletions(-) diff --git a/src/messaging/ApplicationExchangeDispatch.cpp b/src/messaging/ApplicationExchangeDispatch.cpp index 02e7ff3ad131b7..7e7caf59c1e8e0 100644 --- a/src/messaging/ApplicationExchangeDispatch.cpp +++ b/src/messaging/ApplicationExchangeDispatch.cpp @@ -49,10 +49,10 @@ bool ApplicationExchangeDispatch::MessagePermitted(uint16_t protocol, uint8_t ty { case static_cast(Protocols::SecureChannel::MsgType::PBKDFParamRequest): case static_cast(Protocols::SecureChannel::MsgType::PBKDFParamResponse): - case static_cast(Protocols::SecureChannel::MsgType::PASE_Spake2p1): - case static_cast(Protocols::SecureChannel::MsgType::PASE_Spake2p2): - case static_cast(Protocols::SecureChannel::MsgType::PASE_Spake2p3): - case static_cast(Protocols::SecureChannel::MsgType::PASE_Spake2pError): + case static_cast(Protocols::SecureChannel::MsgType::PASE_Pake1): + case static_cast(Protocols::SecureChannel::MsgType::PASE_Pake2): + case static_cast(Protocols::SecureChannel::MsgType::PASE_Pake3): + case static_cast(Protocols::SecureChannel::MsgType::PASE_PakeError): case static_cast(Protocols::SecureChannel::MsgType::CASE_SigmaR1): case static_cast(Protocols::SecureChannel::MsgType::CASE_SigmaR2): case static_cast(Protocols::SecureChannel::MsgType::CASE_SigmaR3): diff --git a/src/protocols/secure_channel/Constants.h b/src/protocols/secure_channel/Constants.h index ed0f4078b2a3b7..387f7ff41a0ecd 100644 --- a/src/protocols/secure_channel/Constants.h +++ b/src/protocols/secure_channel/Constants.h @@ -57,10 +57,10 @@ enum class MsgType : uint8_t // Password-based session establishment Message Types PBKDFParamRequest = 0x20, PBKDFParamResponse = 0x21, - PASE_Spake2p1 = 0x22, - PASE_Spake2p2 = 0x23, - PASE_Spake2p3 = 0x24, - PASE_Spake2pError = 0x2F, + PASE_Pake1 = 0x22, + PASE_Pake2 = 0x23, + PASE_Pake3 = 0x24, + PASE_PakeError = 0x2F, // Certificate-based session establishment Message Types CASE_SigmaR1 = 0x30, diff --git a/src/protocols/secure_channel/PASESession.cpp b/src/protocols/secure_channel/PASESession.cpp index d7f35b99c2841a..1da95010e67bd9 100644 --- a/src/protocols/secure_channel/PASESession.cpp +++ b/src/protocols/secure_channel/PASESession.cpp @@ -85,7 +85,7 @@ void PASESession::Clear() memset(&mPoint[0], 0, sizeof(mPoint)); memset(&mPASEVerifier, 0, sizeof(mPASEVerifier)); memset(&mKe[0], 0, sizeof(mKe)); - mNextExpectedMsg = MsgType::StatusReport; + mNextExpectedMsg = MsgType::PASE_PakeError; // Note: we don't need to explicitly clear the state of mSpake2p object. // Clearing the following state takes care of it. @@ -349,12 +349,11 @@ CHIP_ERROR PASESession::DeriveSecureSession(SecureSession & session, SecureSessi CHIP_ERROR PASESession::SendPBKDFParamRequest() { - uint8_t initiatorRandom[kPBKDFParamRandomNumberSize] = { 0 }; - ReturnErrorOnFailure(DRBG_get_bytes(initiatorRandom, kPBKDFParamRandomNumberSize)); + ReturnErrorOnFailure(DRBG_get_bytes(mPBKDFLocalRandomData, sizeof(mPBKDFLocalRandomData))); - size_t data_len = + const size_t max_msg_len = EstimateTLVStructOverhead(kPBKDFParamRandomNumberSize + sizeof(uint16_t) + sizeof(uint16_t) + sizeof(uint8_t), 4); - System::PacketBufferHandle req = System::PacketBufferHandle::New(data_len); + System::PacketBufferHandle req = System::PacketBufferHandle::New(max_msg_len); VerifyOrReturnError(!req.IsNull(), CHIP_ERROR_NO_MEMORY); System::PacketBufferTLVWriter tlvWriter; @@ -362,10 +361,11 @@ CHIP_ERROR PASESession::SendPBKDFParamRequest() TLV::TLVType outerContainerType = TLV::kTLVType_NotSpecified; ReturnErrorOnFailure(tlvWriter.StartContainer(TLV::AnonymousTag, TLV::kTLVType_Structure, outerContainerType)); - ReturnErrorOnFailure(tlvWriter.PutBytes(TLV::ContextTag(1), initiatorRandom, sizeof(initiatorRandom))); + ReturnErrorOnFailure(tlvWriter.PutBytes(TLV::ContextTag(1), mPBKDFLocalRandomData, sizeof(mPBKDFLocalRandomData))); ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(2), GetLocalKeyId(), true)); ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(3), mPasscodeID, true)); ReturnErrorOnFailure(tlvWriter.PutBoolean(TLV::ContextTag(4), mHavePBKDFParameters)); + // TODO - Add optional MRP parameter support to PASE ReturnErrorOnFailure(tlvWriter.EndContainer(outerContainerType)); ReturnErrorOnFailure(tlvWriter.Finalize(&req)); @@ -411,7 +411,7 @@ CHIP_ERROR PASESession::HandlePBKDFParamRequest(System::PacketBufferHandle && ms VerifyOrExit(TLV::TagNumFromTag(tlvReader.GetTag()) == ++decodeTagIdSeq, err = CHIP_ERROR_INVALID_TLV_TAG); SuccessOrExit(err = tlvReader.Get(initiatorSessionId)); - ChipLogDetail(SecureChannel, "Peer assigned session key ID %d", initiatorSessionId); + ChipLogDetail(SecureChannel, "Peer assigned session ID %d", initiatorSessionId); SetPeerKeyId(initiatorSessionId); SuccessOrExit(err = tlvReader.Next()); @@ -422,6 +422,8 @@ CHIP_ERROR PASESession::HandlePBKDFParamRequest(System::PacketBufferHandle && ms VerifyOrExit(TLV::TagNumFromTag(tlvReader.GetTag()) == ++decodeTagIdSeq, err = CHIP_ERROR_INVALID_TLV_TAG); SuccessOrExit(err = tlvReader.Get(hasPBKDFParameters)); + // TODO - Check if optional MRP parameters were sent. If so, cache them. + err = SendPBKDFParamResponse(ByteSpan(initiatorRandom), hasPBKDFParameters); SuccessOrExit(err); @@ -436,12 +438,11 @@ CHIP_ERROR PASESession::HandlePBKDFParamRequest(System::PacketBufferHandle && ms CHIP_ERROR PASESession::SendPBKDFParamResponse(ByteSpan initiatorRandom, bool initiatorHasPBKDFParams) { - uint8_t responderRandom[kPBKDFParamRandomNumberSize] = { 0 }; - ReturnErrorOnFailure(DRBG_get_bytes(responderRandom, kPBKDFParamRandomNumberSize)); + ReturnErrorOnFailure(DRBG_get_bytes(mPBKDFLocalRandomData, sizeof(mPBKDFLocalRandomData))); - size_t data_len = EstimateTLVStructOverhead( + const size_t max_msg_len = EstimateTLVStructOverhead( kPBKDFParamRandomNumberSize + kPBKDFParamRandomNumberSize + sizeof(uint16_t) + sizeof(uint16_t) + sizeof(uint8_t), 5); - System::PacketBufferHandle resp = System::PacketBufferHandle::New(data_len); + System::PacketBufferHandle resp = System::PacketBufferHandle::New(max_msg_len); VerifyOrReturnError(!resp.IsNull(), CHIP_ERROR_NO_MEMORY); System::PacketBufferTLVWriter tlvWriter; @@ -449,17 +450,18 @@ CHIP_ERROR PASESession::SendPBKDFParamResponse(ByteSpan initiatorRandom, bool in TLV::TLVType outerContainerType = TLV::kTLVType_NotSpecified; ReturnErrorOnFailure(tlvWriter.StartContainer(TLV::AnonymousTag, TLV::kTLVType_Structure, outerContainerType)); + // The initiator random value is being sent back in the response as required by the specifications ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(1), initiatorRandom)); - ReturnErrorOnFailure(tlvWriter.PutBytes(TLV::ContextTag(2), responderRandom, sizeof(responderRandom))); + ReturnErrorOnFailure(tlvWriter.PutBytes(TLV::ContextTag(2), mPBKDFLocalRandomData, sizeof(mPBKDFLocalRandomData))); ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(3), GetLocalKeyId(), true)); if (!initiatorHasPBKDFParams) { - TLV::TLVType outerContainer; - ReturnErrorOnFailure(tlvWriter.StartContainer(TLV::ContextTag(4), TLV::kTLVType_Structure, outerContainer)); + TLV::TLVType pbkdfParamContainer; + ReturnErrorOnFailure(tlvWriter.StartContainer(TLV::ContextTag(4), TLV::kTLVType_Structure, pbkdfParamContainer)); ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(1), mIterationCount, true)); ReturnErrorOnFailure(tlvWriter.PutBytes(TLV::ContextTag(2), mSalt, mSaltLength)); - ReturnErrorOnFailure(tlvWriter.EndContainer(outerContainer)); + ReturnErrorOnFailure(tlvWriter.EndContainer(pbkdfParamContainer)); } ReturnErrorOnFailure(tlvWriter.EndContainer(outerContainerType)); @@ -472,7 +474,7 @@ CHIP_ERROR PASESession::SendPBKDFParamResponse(ByteSpan initiatorRandom, bool in size_t sizeof_point = sizeof(mPoint); ReturnErrorOnFailure(mSpake2p.ComputeL(mPoint, &sizeof_point, mPASEVerifier.mL, kSpake2p_WS_Length)); - mNextExpectedMsg = MsgType::PASE_Spake2p1; + mNextExpectedMsg = MsgType::PASE_Pake1; ReturnErrorOnFailure( mExchangeCtxt->SendMessage(MsgType::PBKDFParamResponse, std::move(resp), SendFlags(SendMessageFlags::kExpectResponse))); @@ -509,6 +511,7 @@ CHIP_ERROR PASESession::HandlePBKDFParamResponse(System::PacketBufferHandle && m VerifyOrExit(TLV::TagNumFromTag(tlvReader.GetTag()) == ++decodeTagIdSeq, err = CHIP_ERROR_INVALID_TLV_TAG); // Initiator's random value SuccessOrExit(err = tlvReader.GetBytes(random, sizeof(random))); + VerifyOrExit(ByteSpan(random).data_equal(ByteSpan(mPBKDFLocalRandomData)), err = CHIP_ERROR_INVALID_PASE_PARAMETER); SuccessOrExit(err = tlvReader.Next()); VerifyOrExit(TLV::TagNumFromTag(tlvReader.GetTag()) == ++decodeTagIdSeq, err = CHIP_ERROR_INVALID_TLV_TAG); @@ -519,12 +522,12 @@ CHIP_ERROR PASESession::HandlePBKDFParamResponse(System::PacketBufferHandle && m VerifyOrExit(TLV::TagNumFromTag(tlvReader.GetTag()) == ++decodeTagIdSeq, err = CHIP_ERROR_INVALID_TLV_TAG); SuccessOrExit(err = tlvReader.Get(responderSessionId)); - ChipLogDetail(SecureChannel, "Peer assigned session key ID %d", responderSessionId); + ChipLogDetail(SecureChannel, "Peer assigned session ID %d", responderSessionId); SetPeerKeyId(responderSessionId); if (mHavePBKDFParameters) { - err = SetupSpake2p(iterCount, ByteSpan(mSalt, mSaltLength)); + err = SetupSpake2p(mIterationCount, ByteSpan(mSalt, mSaltLength)); SuccessOrExit(err); } else @@ -559,8 +562,8 @@ CHIP_ERROR PASESession::HandlePBKDFParamResponse(System::PacketBufferHandle && m CHIP_ERROR PASESession::SendMsg1() { - size_t data_len = EstimateTLVStructOverhead(kMAX_Point_Length, 1); - System::PacketBufferHandle msg = System::PacketBufferHandle::New(data_len); + const size_t max_msg_len = EstimateTLVStructOverhead(kMAX_Point_Length, 1); + System::PacketBufferHandle msg = System::PacketBufferHandle::New(max_msg_len); VerifyOrReturnError(!msg.IsNull(), CHIP_ERROR_NO_MEMORY); System::PacketBufferTLVWriter tlvWriter; @@ -572,24 +575,25 @@ CHIP_ERROR PASESession::SendMsg1() uint8_t X[kMAX_Point_Length]; size_t X_len = sizeof(X); + constexpr uint8_t kPake1_pA = 1; + ReturnErrorOnFailure( mSpake2p.BeginProver(nullptr, 0, nullptr, 0, mPASEVerifier.mW0, kSpake2p_WS_Length, mPASEVerifier.mL, kSpake2p_WS_Length)); ReturnErrorOnFailure(mSpake2p.ComputeRoundOne(NULL, 0, X, &X_len)); - ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(1), ByteSpan(X, X_len))); + ReturnErrorOnFailure(tlvWriter.Put(TLV::ContextTag(kPake1_pA), ByteSpan(X, X_len))); ReturnErrorOnFailure(tlvWriter.EndContainer(outerContainerType)); ReturnErrorOnFailure(tlvWriter.Finalize(&msg)); - mNextExpectedMsg = MsgType::PASE_Spake2p2; + mNextExpectedMsg = MsgType::PASE_Pake2; - // Call delegate to send the Msg1 to peer ReturnErrorOnFailure( - mExchangeCtxt->SendMessage(MsgType::PASE_Spake2p1, std::move(msg), SendFlags(SendMessageFlags::kExpectResponse))); + mExchangeCtxt->SendMessage(MsgType::PASE_Pake1, std::move(msg), SendFlags(SendMessageFlags::kExpectResponse))); ChipLogDetail(SecureChannel, "Sent spake2p msg1"); return CHIP_NO_ERROR; } -CHIP_ERROR PASESession::HandleMsg1_and_SendMsg2(System::PacketBufferHandle && msg) +CHIP_ERROR PASESession::HandleMsg1_and_SendMsg2(System::PacketBufferHandle && msg1) { CHIP_ERROR err = CHIP_NO_ERROR; @@ -607,7 +611,7 @@ CHIP_ERROR PASESession::HandleMsg1_and_SendMsg2(System::PacketBufferHandle && ms const uint8_t * X; size_t X_len = 0; - tlvReader.Init(std::move(msg)); + tlvReader.Init(std::move(msg1)); SuccessOrExit(err = tlvReader.Next(containerType, TLV::AnonymousTag)); SuccessOrExit(err = tlvReader.EnterContainer(containerType)); @@ -618,14 +622,16 @@ CHIP_ERROR PASESession::HandleMsg1_and_SendMsg2(System::PacketBufferHandle && ms SuccessOrExit( err = mSpake2p.BeginVerifier(nullptr, 0, nullptr, 0, mPASEVerifier.mW0, kSpake2p_WS_Length, mPoint, sizeof(mPoint))); - // Pass Pa to check abort condition. SuccessOrExit(err = mSpake2p.ComputeRoundOne(X, X_len, Y, &Y_len)); SuccessOrExit(err = mSpake2p.ComputeRoundTwo(X, X_len, verifier, &verifier_len)); + msg1 = nullptr; { - size_t data_len = EstimateTLVStructOverhead(Y_len + verifier_len, 2); + const size_t max_msg_len = EstimateTLVStructOverhead(Y_len + verifier_len, 2); + constexpr uint8_t kPake2_pB = 1; + constexpr uint8_t kPake2_cB = 2; - System::PacketBufferHandle msg2 = System::PacketBufferHandle::New(data_len); + System::PacketBufferHandle msg2 = System::PacketBufferHandle::New(max_msg_len); VerifyOrExit(!msg2.IsNull(), err = CHIP_ERROR_NO_MEMORY); System::PacketBufferTLVWriter tlvWriter; @@ -633,15 +639,14 @@ CHIP_ERROR PASESession::HandleMsg1_and_SendMsg2(System::PacketBufferHandle && ms TLV::TLVType outerContainerType = TLV::kTLVType_NotSpecified; SuccessOrExit(err = tlvWriter.StartContainer(TLV::AnonymousTag, TLV::kTLVType_Structure, outerContainerType)); - SuccessOrExit(err = tlvWriter.Put(TLV::ContextTag(1), ByteSpan(Y, Y_len))); - SuccessOrExit(err = tlvWriter.Put(TLV::ContextTag(2), ByteSpan(verifier, verifier_len))); + SuccessOrExit(err = tlvWriter.Put(TLV::ContextTag(kPake2_pB), ByteSpan(Y, Y_len))); + SuccessOrExit(err = tlvWriter.Put(TLV::ContextTag(kPake2_cB), ByteSpan(verifier, verifier_len))); SuccessOrExit(err = tlvWriter.EndContainer(outerContainerType)); SuccessOrExit(err = tlvWriter.Finalize(&msg2)); - mNextExpectedMsg = MsgType::PASE_Spake2p3; + mNextExpectedMsg = MsgType::PASE_Pake3; - // Call delegate to send the Msg2 to peer - err = mExchangeCtxt->SendMessage(MsgType::PASE_Spake2p2, std::move(msg2), SendFlags(SendMessageFlags::kExpectResponse)); + err = mExchangeCtxt->SendMessage(MsgType::PASE_Pake2, std::move(msg2), SendFlags(SendMessageFlags::kExpectResponse)); SuccessOrExit(err); } @@ -656,7 +661,7 @@ CHIP_ERROR PASESession::HandleMsg1_and_SendMsg2(System::PacketBufferHandle && ms return err; } -CHIP_ERROR PASESession::HandleMsg2_and_SendMsg3(System::PacketBufferHandle && msg) +CHIP_ERROR PASESession::HandleMsg2_and_SendMsg3(System::PacketBufferHandle && msg2) { CHIP_ERROR err = CHIP_NO_ERROR; @@ -680,7 +685,7 @@ CHIP_ERROR PASESession::HandleMsg2_and_SendMsg3(System::PacketBufferHandle && ms uint32_t decodeTagIdSeq = 0; - tlvReader.Init(std::move(msg)); + tlvReader.Init(std::move(msg2)); SuccessOrExit(err = tlvReader.Next(containerType, TLV::AnonymousTag)); SuccessOrExit(err = tlvReader.EnterContainer(containerType)); @@ -704,11 +709,13 @@ CHIP_ERROR PASESession::HandleMsg2_and_SendMsg3(System::PacketBufferHandle && ms } SuccessOrExit(err = mSpake2p.GetKeys(mKe, &mKeLen)); + msg2 = nullptr; { - size_t data_len = EstimateTLVStructOverhead(verifier_len, 1); + const size_t max_msg_len = EstimateTLVStructOverhead(verifier_len, 1); + constexpr uint8_t kPake3_cB = 1; - System::PacketBufferHandle msg3 = System::PacketBufferHandle::New(data_len); + System::PacketBufferHandle msg3 = System::PacketBufferHandle::New(max_msg_len); VerifyOrExit(!msg3.IsNull(), err = CHIP_ERROR_NO_MEMORY); System::PacketBufferTLVWriter tlvWriter; @@ -716,14 +723,13 @@ CHIP_ERROR PASESession::HandleMsg2_and_SendMsg3(System::PacketBufferHandle && ms TLV::TLVType outerContainerType = TLV::kTLVType_NotSpecified; SuccessOrExit(err = tlvWriter.StartContainer(TLV::AnonymousTag, TLV::kTLVType_Structure, outerContainerType)); - SuccessOrExit(err = tlvWriter.Put(TLV::ContextTag(1), ByteSpan(verifier, verifier_len))); + SuccessOrExit(err = tlvWriter.Put(TLV::ContextTag(kPake3_cB), ByteSpan(verifier, verifier_len))); SuccessOrExit(err = tlvWriter.EndContainer(outerContainerType)); SuccessOrExit(err = tlvWriter.Finalize(&msg3)); mNextExpectedMsg = MsgType::StatusReport; - // Call delegate to send the Msg3 to peer - err = mExchangeCtxt->SendMessage(MsgType::PASE_Spake2p3, std::move(msg3), SendFlags(SendMessageFlags::kExpectResponse)); + err = mExchangeCtxt->SendMessage(MsgType::PASE_Pake3, std::move(msg3), SendFlags(SendMessageFlags::kExpectResponse)); SuccessOrExit(err); } @@ -745,9 +751,8 @@ CHIP_ERROR PASESession::HandleMsg3(System::PacketBufferHandle && msg) ChipLogDetail(SecureChannel, "Received spake2p msg3"); - // We will set NextExpectedMsg to StatusReport in all cases - // However, when we are using IP rendezvous, we might set it to PASE_Spake2p1. - mNextExpectedMsg = MsgType::StatusReport; + // We will set NextExpectedMsg to PASE_PakeError in all cases + mNextExpectedMsg = MsgType::PASE_PakeError; System::PacketBufferTLVReader tlvReader; TLV::TLVType containerType = TLV::kTLVType_Structure; @@ -775,6 +780,7 @@ CHIP_ERROR PASESession::HandleMsg3(System::PacketBufferHandle && msg) SuccessOrExit(err = mSpake2p.GetKeys(mKe, &mKeLen)); + // Send confirmation to peer that we succeeded so they can start using the session. SendStatusReport(kProtocolCodeSuccess); mPairingComplete = true; @@ -898,15 +904,15 @@ CHIP_ERROR PASESession::OnMessageReceived(ExchangeContext * exchange, const Pack err = HandlePBKDFParamResponse(std::move(msg)); break; - case MsgType::PASE_Spake2p1: + case MsgType::PASE_Pake1: err = HandleMsg1_and_SendMsg2(std::move(msg)); break; - case MsgType::PASE_Spake2p2: + case MsgType::PASE_Pake2: err = HandleMsg2_and_SendMsg3(std::move(msg)); break; - case MsgType::PASE_Spake2p3: + case MsgType::PASE_Pake3: err = HandleMsg3(std::move(msg)); break; diff --git a/src/protocols/secure_channel/PASESession.h b/src/protocols/secure_channel/PASESession.h index c02ff0bf9b21e3..fa261b8d5ffcac 100644 --- a/src/protocols/secure_channel/PASESession.h +++ b/src/protocols/secure_channel/PASESession.h @@ -276,7 +276,7 @@ class DLL_EXPORT PASESession : public Messaging::ExchangeDelegate, public Pairin SessionEstablishmentDelegate * mDelegate = nullptr; - Protocols::SecureChannel::MsgType mNextExpectedMsg = Protocols::SecureChannel::MsgType::PASE_Spake2pError; + Protocols::SecureChannel::MsgType mNextExpectedMsg = Protocols::SecureChannel::MsgType::PASE_PakeError; #ifdef ENABLE_HSM_SPAKE Spake2pHSM_P256_SHA256_HKDF_HMAC mSpake2p; @@ -296,6 +296,8 @@ class DLL_EXPORT PASESession : public Messaging::ExchangeDelegate, public Pairin bool mHavePBKDFParameters = false; + uint8_t mPBKDFLocalRandomData[kPBKDFParamRandomNumberSize]; + Hash_SHA256_stream mCommissioningHash; uint32_t mIterationCount = 0; uint16_t mSaltLength = 0; diff --git a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp b/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp index 4e14810196fa89..318e655c35da7e 100644 --- a/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp +++ b/src/protocols/secure_channel/SessionEstablishmentExchangeDispatch.cpp @@ -68,10 +68,10 @@ bool SessionEstablishmentExchangeDispatch::MessagePermitted(uint16_t protocol, u case static_cast(Protocols::SecureChannel::MsgType::StandaloneAck): case static_cast(Protocols::SecureChannel::MsgType::PBKDFParamRequest): case static_cast(Protocols::SecureChannel::MsgType::PBKDFParamResponse): - case static_cast(Protocols::SecureChannel::MsgType::PASE_Spake2p1): - case static_cast(Protocols::SecureChannel::MsgType::PASE_Spake2p2): - case static_cast(Protocols::SecureChannel::MsgType::PASE_Spake2p3): - case static_cast(Protocols::SecureChannel::MsgType::PASE_Spake2pError): + case static_cast(Protocols::SecureChannel::MsgType::PASE_Pake1): + case static_cast(Protocols::SecureChannel::MsgType::PASE_Pake2): + case static_cast(Protocols::SecureChannel::MsgType::PASE_Pake3): + case static_cast(Protocols::SecureChannel::MsgType::PASE_PakeError): case static_cast(Protocols::SecureChannel::MsgType::CASE_SigmaR1): case static_cast(Protocols::SecureChannel::MsgType::CASE_SigmaR2): case static_cast(Protocols::SecureChannel::MsgType::CASE_SigmaR3):