Skip to content

Commit

Permalink
Integrate CASE MRP parameters with controller and CASE server (#12738)
Browse files Browse the repository at this point in the history
* Integrate CASE MRP parameters with controller and CASE server

* fix tests

* fix test build
  • Loading branch information
pan-apple authored and pull[bot] committed Feb 23, 2022
1 parent 689fbc5 commit 8d80165
Show file tree
Hide file tree
Showing 12 changed files with 58 additions and 35 deletions.
4 changes: 2 additions & 2 deletions src/app/CASEClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ CHIP_ERROR CASEClient::EstablishSession(PeerId peer, const Transport::PeerAddres
uint16_t keyID = 0;
ReturnErrorOnFailure(mInitParams.idAllocator->Allocate(keyID));

ReturnErrorOnFailure(
mCASESession.EstablishSession(peerAddress, mInitParams.fabricInfo, peer.GetNodeId(), keyID, exchange, this));
ReturnErrorOnFailure(mCASESession.EstablishSession(peerAddress, mInitParams.fabricInfo, peer.GetNodeId(), keyID, exchange, this,
mInitParams.mrpLocalConfig));
mConnectionSuccessCallback = onConnection;
mConnectionFailureCallback = onFailure;
mConectionContext = context;
Expand Down
2 changes: 2 additions & 0 deletions src/app/CASEClient.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ struct CASEClientInitParams
Messaging::ExchangeManager * exchangeMgr = nullptr;
SessionIDAllocator * idAllocator = nullptr;
FabricInfo * fabricInfo = nullptr;

Optional<ReliableMessageProtocolConfig> mrpLocalConfig = Optional<ReliableMessageProtocolConfig>::Missing();
};

class DLL_EXPORT CASEClient : public SessionEstablishmentDelegate
Expand Down
3 changes: 2 additions & 1 deletion src/app/OperationalDeviceProxy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,8 @@ bool OperationalDeviceProxy::GetAddress(Inet::IPAddress & addr, uint16_t & port)
CHIP_ERROR OperationalDeviceProxy::EstablishConnection()
{
mCASEClient = mInitParams.clientPool->Allocate(CASEClientInitParams{ mInitParams.sessionManager, mInitParams.exchangeMgr,
mInitParams.idAllocator, mInitParams.fabricInfo });
mInitParams.idAllocator, mInitParams.fabricInfo,
mInitParams.mrpLocalConfig });
ReturnErrorCodeIf(mCASEClient == nullptr, CHIP_ERROR_NO_MEMORY);
CHIP_ERROR err =
mCASEClient->EstablishSession(mPeerId, mDeviceAddress, mMRPConfig, HandleCASEConnected, HandleCASEConnectionFailure, this);
Expand Down
2 changes: 2 additions & 0 deletions src/app/OperationalDeviceProxy.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ struct DeviceProxyInitParams

Controller::DeviceControllerInteractionModelDelegate * imDelegate = nullptr;

Optional<ReliableMessageProtocolConfig> mrpLocalConfig = Optional<ReliableMessageProtocolConfig>::Missing();

CHIP_ERROR Validate()
{
ReturnErrorCodeIf(sessionManager == nullptr, CHIP_ERROR_INCORRECT_STATE);
Expand Down
1 change: 1 addition & 0 deletions src/controller/CHIPDeviceController.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ CHIP_ERROR DeviceController::Init(ControllerInitParams params)
.fabricInfo = params.systemState->Fabrics()->FindFabricWithIndex(mFabricIndex),
.clientPool = &mCASEClientPool,
.imDelegate = params.systemState->IMDelegate(),
.mrpLocalConfig = Optional<ReliableMessageProtocolConfig>::Value(mMRPConfig),
};

CASESessionManagerConfig sessionManagerConfig = {
Expand Down
4 changes: 2 additions & 2 deletions src/controller/CHIPDeviceController.h
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,8 @@ class DLL_EXPORT DeviceController : public SessionReleaseDelegate,

uint16_t mVendorId;

ReliableMessageProtocolConfig mMRPConfig = gDefaultMRPConfig;

//////////// SessionReleaseDelegate Implementation ///////////////
void OnSessionReleased(SessionHandle session) override;

Expand Down Expand Up @@ -819,8 +821,6 @@ class DLL_EXPORT DeviceCommissioner : public DeviceController,

Callback::Callback<OnNOCChainGeneration> mDeviceNOCChainCallback;
SetUpCodePairer mSetUpCodePairer;

ReliableMessageProtocolConfig mMRPConfig = gDefaultMRPConfig;
};

} // namespace Controller
Expand Down
3 changes: 2 additions & 1 deletion src/protocols/secure_channel/CASEServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ CHIP_ERROR CASEServer::InitCASEHandshake(Messaging::ExchangeContext * ec)
ReturnErrorOnFailure(mIDAllocator->Allocate(mSessionKeyId));

// Setup CASE state machine using the credentials for the current fabric.
ReturnErrorOnFailure(GetSession().ListenForSessionEstablishment(mSessionKeyId, mFabrics, this));
ReturnErrorOnFailure(GetSession().ListenForSessionEstablishment(
mSessionKeyId, mFabrics, this, Optional<ReliableMessageProtocolConfig>::Value(gDefaultMRPConfig)));

// Hand over the exchange context to the CASE session.
ec->SetDelegate(&GetSession());
Expand Down
30 changes: 19 additions & 11 deletions src/protocols/secure_channel/CASESession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,12 +291,12 @@ CHIP_ERROR CASESession::DeriveSecureSession(CryptoContext & session, CryptoConte

CHIP_ERROR CASESession::SendSigma1()
{
size_t data_len = TLV::EstimateStructOverhead(kSigmaParamRandomNumberSize, // initiatorRandom
const size_t mrpParamsSize = mLocalMRPConfig.HasValue() ? TLV::EstimateStructOverhead(sizeof(uint16_t), sizeof(uint16_t)) : 0;
size_t data_len = TLV::EstimateStructOverhead(kSigmaParamRandomNumberSize, // initiatorRandom
sizeof(uint16_t), // initiatorSessionId,
kSHA256_Hash_Length, // destinationId
kP256_PublicKey_Length, // InitiatorEphPubKey,
/* TLV::EstimateStructOverhead(sizeof(uint16_t),
sizeof(uint16)_t), // initiatorMRPParams */
mrpParamsSize, // initiatorMRPParams
kCASEResumptionIDSize, CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES);

System::PacketBufferTLVWriter tlvWriter;
Expand Down Expand Up @@ -461,8 +461,9 @@ CHIP_ERROR CASESession::HandleSigma1(System::PacketBufferHandle && msg)

CHIP_ERROR CASESession::SendSigma2Resume(const ByteSpan & initiatorRandom)
{
size_t max_sigma2_resume_data_len = TLV::EstimateStructOverhead(kCASEResumptionIDSize, CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES,
sizeof(uint16_t) /*, kMRPOptionalParamsLength, */);
const size_t mrpParamsSize = mLocalMRPConfig.HasValue() ? TLV::EstimateStructOverhead(sizeof(uint16_t), sizeof(uint16_t)) : 0;
size_t max_sigma2_resume_data_len =
TLV::EstimateStructOverhead(kCASEResumptionIDSize, CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES, sizeof(uint16_t), mrpParamsSize);

System::PacketBufferTLVWriter tlvWriter;
System::PacketBufferHandle msg_R2_resume;
Expand Down Expand Up @@ -603,8 +604,9 @@ CHIP_ERROR CASESession::SendSigma2()
CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES));

// Construct Sigma2 Msg
size_t data_len = TLV::EstimateStructOverhead(kSigmaParamRandomNumberSize, sizeof(uint16_t), kP256_PublicKey_Length,
msg_r2_signed_enc_len, CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES);
const size_t mrpParamsSize = mLocalMRPConfig.HasValue() ? TLV::EstimateStructOverhead(sizeof(uint16_t), sizeof(uint16_t)) : 0;
size_t data_len = TLV::EstimateStructOverhead(kSigmaParamRandomNumberSize, sizeof(uint16_t), kP256_PublicKey_Length,
msg_r2_signed_enc_len, CHIP_CRYPTO_AEAD_MIC_LENGTH_BYTES, mrpParamsSize);

System::PacketBufferHandle msg_R2 = System::PacketBufferHandle::New(data_len);
VerifyOrReturnError(!msg_R2.IsNull(), CHIP_ERROR_NO_MEMORY);
Expand All @@ -623,7 +625,7 @@ CHIP_ERROR CASESession::SendSigma2()
if (mLocalMRPConfig.HasValue())
{
ChipLogDetail(SecureChannel, "Including MRP parameters");
ReturnErrorOnFailure(EncodeMRPParameters(TLV::ContextTag(5), mLocalMRPConfig.Value(), tlvWriter));
ReturnErrorOnFailure(EncodeMRPParameters(TLV::ContextTag(5), mLocalMRPConfig.Value(), tlvWriterMsg2));
}
ReturnErrorOnFailure(tlvWriterMsg2.EndContainer(outerContainerType));
ReturnErrorOnFailure(tlvWriterMsg2.Finalize(&msg_R2));
Expand Down Expand Up @@ -680,7 +682,10 @@ CHIP_ERROR CASESession::HandleSigma2Resume(System::PacketBufferHandle && msg)
VerifyOrExit(TLV::TagNumFromTag(tlvReader.GetTag()) == ++decodeTagIdSeq, err = CHIP_ERROR_INVALID_TLV_TAG);
SuccessOrExit(err = tlvReader.Get(responderSessionId));

SuccessOrExit(err = DecodeMRPParametersIfPresent(tlvReader));
if (tlvReader.Next() != CHIP_END_OF_TLV)
{
SuccessOrExit(err = DecodeMRPParametersIfPresent(TLV::ContextTag(4), tlvReader));
}

ChipLogDetail(SecureChannel, "Peer assigned session session ID %d", responderSessionId);
SetPeerSessionId(responderSessionId);
Expand Down Expand Up @@ -852,7 +857,10 @@ CHIP_ERROR CASESession::HandleSigma2(System::PacketBufferHandle && msg)
SetPeerCATs(peerCATs);

// Retrieve responderMRPParams if present
SuccessOrExit(err = DecodeMRPParametersIfPresent(tlvReader));
if (tlvReader.Next() != CHIP_END_OF_TLV)
{
SuccessOrExit(err = DecodeMRPParametersIfPresent(TLV::ContextTag(5), tlvReader));
}

exit:
if (err != CHIP_NO_ERROR)
Expand Down Expand Up @@ -1378,7 +1386,7 @@ CHIP_ERROR CASESession::ParseSigma1(TLV::ContiguousBufferTLVReader & tlvReader,
CHIP_ERROR err = tlvReader.Next();
if (err == CHIP_NO_ERROR && tlvReader.GetTag() == ContextTag(kInitiatorMRPParamsTag))
{
ReturnErrorOnFailure(DecodeMRPParametersIfPresent(tlvReader));
ReturnErrorOnFailure(DecodeMRPParametersIfPresent(TLV::ContextTag(kInitiatorMRPParamsTag), tlvReader));
err = tlvReader.Next();
}

Expand Down
15 changes: 12 additions & 3 deletions src/protocols/secure_channel/PASESession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,10 @@ CHIP_ERROR PASESession::HandlePBKDFParamRequest(System::PacketBufferHandle && ms
VerifyOrExit(TLV::TagNumFromTag(tlvReader.GetTag()) == ++decodeTagIdSeq, err = CHIP_ERROR_INVALID_TLV_TAG);
SuccessOrExit(err = tlvReader.Get(hasPBKDFParameters));

SuccessOrExit(err = DecodeMRPParametersIfPresent(tlvReader));
if (tlvReader.Next() != CHIP_END_OF_TLV)
{
SuccessOrExit(err = DecodeMRPParametersIfPresent(TLV::ContextTag(5), tlvReader));
}

err = SendPBKDFParamResponse(ByteSpan(initiatorRandom), hasPBKDFParameters);
SuccessOrExit(err);
Expand Down Expand Up @@ -562,7 +565,10 @@ CHIP_ERROR PASESession::HandlePBKDFParamResponse(System::PacketBufferHandle && m

if (mHavePBKDFParameters)
{
SuccessOrExit(err = DecodeMRPParametersIfPresent(tlvReader));
if (tlvReader.Next() != CHIP_END_OF_TLV)
{
SuccessOrExit(err = DecodeMRPParametersIfPresent(TLV::ContextTag(5), tlvReader));
}

// TODO - Add a unit test that exercises mHavePBKDFParameters path
err = SetupSpake2p(mIterationCount, ByteSpan(mSalt, mSaltLength));
Expand All @@ -585,7 +591,10 @@ CHIP_ERROR PASESession::HandlePBKDFParamResponse(System::PacketBufferHandle && m

SuccessOrExit(err = tlvReader.ExitContainer(containerType));

SuccessOrExit(err = DecodeMRPParametersIfPresent(tlvReader));
if (tlvReader.Next() != CHIP_END_OF_TLV)
{
SuccessOrExit(err = DecodeMRPParametersIfPresent(TLV::ContextTag(5), tlvReader));
}

err = SetupSpake2p(iterCount, ByteSpan(salt, saltLength));
SuccessOrExit(err);
Expand Down
10 changes: 4 additions & 6 deletions src/transport/PairingSession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,13 @@ CHIP_ERROR PairingSession::EncodeMRPParameters(TLV::Tag tag, const ReliableMessa
return tlvWriter.EndContainer(mrpParamsContainer);
}

CHIP_ERROR PairingSession::DecodeMRPParametersIfPresent(TLV::ContiguousBufferTLVReader & tlvReader)
CHIP_ERROR PairingSession::DecodeMRPParametersIfPresent(TLV::Tag expectedTag, TLV::ContiguousBufferTLVReader & tlvReader)
{
// The MRP parameters are optional.
CHIP_ERROR err = tlvReader.Next();
if (err == CHIP_END_OF_TLV)
if (tlvReader.GetTag() != expectedTag)
{
return CHIP_NO_ERROR;
}
ReturnErrorOnFailure(err);

TLV::TLVType containerType = TLV::kTLVType_Structure;
ReturnErrorOnFailure(tlvReader.EnterContainer(containerType));
Expand All @@ -63,10 +61,10 @@ CHIP_ERROR PairingSession::DecodeMRPParametersIfPresent(TLV::ContiguousBufferTLV
mMRPConfig.mIdleRetransTimeout = System::Clock::Milliseconds32(tlvElementValue);

// The next element is optional. If it's not present, return CHIP_NO_ERROR.
err = tlvReader.Next();
CHIP_ERROR err = tlvReader.Next();
if (err == CHIP_END_OF_TLV)
{
return CHIP_NO_ERROR;
return tlvReader.ExitContainer(containerType);
}
ReturnErrorOnFailure(err);
}
Expand Down
4 changes: 2 additions & 2 deletions src/transport/PairingSession.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ class DLL_EXPORT PairingSession
}

/**
* Try to decode the next element (pointed by the TLV reader) as MRP parameters.
* Try to decode the current element (pointed by the TLV reader) as MRP parameters.
* If the MRP parameters are found, mMRPConfig is updated with the devoded values.
*
* MRP parameters are optional. So, if the TLV reader is not pointing to the MRP parameters,
Expand All @@ -172,7 +172,7 @@ class DLL_EXPORT PairingSession
* If the parameters are present, but TLV reader fails to correctly parse it, the function will
* return the corresponding error.
*/
CHIP_ERROR DecodeMRPParametersIfPresent(TLV::ContiguousBufferTLVReader & tlvReader);
CHIP_ERROR DecodeMRPParametersIfPresent(TLV::Tag expectedTag, TLV::ContiguousBufferTLVReader & tlvReader);

// TODO: remove Clear, we should create a new instance instead reset the old instance.
void Clear()
Expand Down
15 changes: 8 additions & 7 deletions src/transport/tests/TestPairingSession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ class TestPairingSession : public PairingSession
public:
CHIP_ERROR DeriveSecureSession(CryptoContext & session, CryptoContext::SessionRole role) override { return CHIP_NO_ERROR; }

CHIP_ERROR DecodeMRPParametersIfPresent(System::PacketBufferTLVReader & tlvReader)
CHIP_ERROR DecodeMRPParametersIfPresent(TLV::Tag expectedTag, System::PacketBufferTLVReader & tlvReader)
{
return PairingSession::DecodeMRPParametersIfPresent(tlvReader);
return PairingSession::DecodeMRPParametersIfPresent(expectedTag, tlvReader);
}
};

Expand All @@ -62,8 +62,7 @@ void PairingSessionEncodeDecodeMRPParams(nlTestSuite * inSuite, void * inContext
TLV::TLVType outerContainerType = TLV::kTLVType_NotSpecified;
NL_TEST_ASSERT(inSuite, writer.StartContainer(TLV::AnonymousTag, TLV::kTLVType_Structure, outerContainerType) == CHIP_NO_ERROR);

CHIP_ERROR err = PairingSession::EncodeMRPParameters(TLV::ContextTag(1), config, writer);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
NL_TEST_ASSERT(inSuite, PairingSession::EncodeMRPParameters(TLV::ContextTag(1), config, writer) == CHIP_NO_ERROR);

NL_TEST_ASSERT(inSuite, writer.EndContainer(outerContainerType) == CHIP_NO_ERROR);
NL_TEST_ASSERT(inSuite, writer.Finalize(&buf) == CHIP_NO_ERROR);
Expand All @@ -75,8 +74,8 @@ void PairingSessionEncodeDecodeMRPParams(nlTestSuite * inSuite, void * inContext
NL_TEST_ASSERT(inSuite, reader.Next(containerType, TLV::AnonymousTag) == CHIP_NO_ERROR);
NL_TEST_ASSERT(inSuite, reader.EnterContainer(containerType) == CHIP_NO_ERROR);

err = session.DecodeMRPParametersIfPresent(reader);
NL_TEST_ASSERT(inSuite, err == CHIP_NO_ERROR);
NL_TEST_ASSERT(inSuite, reader.Next() == CHIP_NO_ERROR);
NL_TEST_ASSERT(inSuite, session.DecodeMRPParametersIfPresent(TLV::ContextTag(1), reader) == CHIP_NO_ERROR);

NL_TEST_ASSERT(inSuite, session.GetMRPConfig().mIdleRetransTimeout == config.mIdleRetransTimeout);
NL_TEST_ASSERT(inSuite, session.GetMRPConfig().mActiveRetransTimeout == config.mActiveRetransTimeout);
Expand All @@ -92,6 +91,7 @@ void PairingSessionTryDecodeMissingMRPParams(nlTestSuite * inSuite, void * inCon

TLV::TLVType outerContainerType = TLV::kTLVType_NotSpecified;
NL_TEST_ASSERT(inSuite, writer.StartContainer(TLV::AnonymousTag, TLV::kTLVType_Structure, outerContainerType) == CHIP_NO_ERROR);
NL_TEST_ASSERT(inSuite, writer.Put(TLV::ContextTag(1), static_cast<uint16_t>(0x1234)) == CHIP_NO_ERROR);
NL_TEST_ASSERT(inSuite, writer.EndContainer(outerContainerType) == CHIP_NO_ERROR);
NL_TEST_ASSERT(inSuite, writer.Finalize(&buf) == CHIP_NO_ERROR);

Expand All @@ -101,7 +101,8 @@ void PairingSessionTryDecodeMissingMRPParams(nlTestSuite * inSuite, void * inCon
reader.Init(std::move(buf));
NL_TEST_ASSERT(inSuite, reader.Next(containerType, TLV::AnonymousTag) == CHIP_NO_ERROR);
NL_TEST_ASSERT(inSuite, reader.EnterContainer(containerType) == CHIP_NO_ERROR);
NL_TEST_ASSERT(inSuite, session.DecodeMRPParametersIfPresent(reader) == CHIP_NO_ERROR);
NL_TEST_ASSERT(inSuite, reader.Next() == CHIP_NO_ERROR);
NL_TEST_ASSERT(inSuite, session.DecodeMRPParametersIfPresent(TLV::ContextTag(2), reader) == CHIP_NO_ERROR);

NL_TEST_ASSERT(inSuite, session.GetMRPConfig().mIdleRetransTimeout == gDefaultMRPConfig.mIdleRetransTimeout);
NL_TEST_ASSERT(inSuite, session.GetMRPConfig().mActiveRetransTimeout == gDefaultMRPConfig.mActiveRetransTimeout);
Expand Down

0 comments on commit 8d80165

Please sign in to comment.