diff --git a/src/protocols/secure_channel/PairingSession.h b/src/protocols/secure_channel/PairingSession.h index ccbbd4a452ae64..c604cd0662dfd3 100644 --- a/src/protocols/secure_channel/PairingSession.h +++ b/src/protocols/secure_channel/PairingSession.h @@ -162,44 +162,40 @@ class DLL_EXPORT PairingSession : public SessionDelegate CHIP_ERROR HandleStatusReport(System::PacketBufferHandle && msg, bool successExpected) { Protocols::SecureChannel::StatusReport report; - CHIP_ERROR err = report.Parse(std::move(msg)); - ReturnErrorOnFailure(err); + ReturnErrorOnFailure(report.Parse(std::move(msg))); VerifyOrReturnError(report.GetProtocolId() == Protocols::SecureChannel::Id, CHIP_ERROR_INVALID_ARGUMENT); if (report.GetGeneralCode() == Protocols::SecureChannel::GeneralStatusCode::kSuccess && report.GetProtocolCode() == Protocols::SecureChannel::kProtocolCodeSuccess && successExpected) { OnSuccessStatusReport(); + return CHIP_NO_ERROR; } - else - { - err = OnFailureStatusReport(report.GetGeneralCode(), report.GetProtocolCode()); - if (report.GetGeneralCode() == Protocols::SecureChannel::GeneralStatusCode::kBusy && - report.GetProtocolCode() == Protocols::SecureChannel::kProtocolCodeBusy) + if (report.GetGeneralCode() == Protocols::SecureChannel::GeneralStatusCode::kBusy && + report.GetProtocolCode() == Protocols::SecureChannel::kProtocolCodeBusy) + { + if (!report.GetProtocolData().IsNull()) { - if (!report.GetProtocolData().IsNull()) + Encoding::LittleEndian::Reader reader(report.GetProtocolData()->Start(), report.GetProtocolData()->DataLength()); + + uint16_t minimumWaitTime = 0; + CHIP_ERROR waitTimeErr = reader.Read16(&minimumWaitTime).StatusCode(); + if (waitTimeErr != CHIP_NO_ERROR) + { + ChipLogError(SecureChannel, "Failed to read the minimum wait time: %" CHIP_ERROR_FORMAT, waitTimeErr.Format()); + } + else { - Encoding::LittleEndian::Reader reader(report.GetProtocolData()->Start(), - report.GetProtocolData()->DataLength()); - - uint16_t minimumWaitTime = 0; - err = reader.Read16(&minimumWaitTime).StatusCode(); - if (err != CHIP_NO_ERROR) - { - ChipLogError(SecureChannel, "Failed to read the minimum wait time: %" CHIP_ERROR_FORMAT, err.Format()); - } - else - { - // TODO: CASE: Notify minimum wait time to clients on receiving busy status report #28290 - ChipLogProgress(SecureChannel, "Received busy status report with minimum wait time: %u ms", - minimumWaitTime); - } + // TODO: CASE: Notify minimum wait time to clients on receiving busy status report #28290 + ChipLogProgress(SecureChannel, "Received busy status report with minimum wait time: %u ms", minimumWaitTime); } } } - return err; + // It's very important that we propagate the return value from + // OnFailureStatusReport out to the caller. Make sure we return it directly. + return OnFailureStatusReport(report.GetGeneralCode(), report.GetProtocolCode()); } /** diff --git a/src/protocols/secure_channel/tests/TestCASESession.cpp b/src/protocols/secure_channel/tests/TestCASESession.cpp index 26cb1e0f71ed58..c0a6e6ac2c5b48 100644 --- a/src/protocols/secure_channel/tests/TestCASESession.cpp +++ b/src/protocols/secure_channel/tests/TestCASESession.cpp @@ -122,7 +122,14 @@ CHIP_ERROR InitFabricTable(chip::FabricTable & fabricTable, chip::TestPersistent class TestCASESecurePairingDelegate : public SessionEstablishmentDelegate { public: - void OnSessionEstablishmentError(CHIP_ERROR error) override { mNumPairingErrors++; } + void OnSessionEstablishmentError(CHIP_ERROR error) override + { + mNumPairingErrors++; + if (error == CHIP_ERROR_BUSY) + { + mNumBusyResponses++; + } + } void OnSessionEstablished(const SessionHandle & session) override { @@ -137,6 +144,7 @@ class TestCASESecurePairingDelegate : public SessionEstablishmentDelegate // TODO: Rename mNumPairing* to mNumEstablishment* uint32_t mNumPairingErrors = 0; uint32_t mNumPairingComplete = 0; + uint32_t mNumBusyResponses = 0; }; class TestOperationalKeystore : public chip::Crypto::OperationalKeystore @@ -314,6 +322,7 @@ class TestCASESession static void SecurePairingStartTest(nlTestSuite * inSuite, void * inContext); static void SecurePairingHandshakeTest(nlTestSuite * inSuite, void * inContext); static void SecurePairingHandshakeServerTest(nlTestSuite * inSuite, void * inContext); + static void ClientReceivesBusyTest(nlTestSuite * inSuite, void * inContext); static void Sigma1ParsingTest(nlTestSuite * inSuite, void * inContext); static void DestinationIdTest(nlTestSuite * inSuite, void * inContext); static void SessionResumptionStorage(nlTestSuite * inSuite, void * inContext); @@ -536,6 +545,58 @@ void TestCASESession::SecurePairingHandshakeServerTest(nlTestSuite * inSuite, vo chip::Platform::Delete(pairingCommissioner); chip::Platform::Delete(pairingCommissioner1); + + gPairingServer.Shutdown(); +} + +void TestCASESession::ClientReceivesBusyTest(nlTestSuite * inSuite, void * inContext) +{ + TestContext & ctx = *reinterpret_cast(inContext); + TemporarySessionManager sessionManager(inSuite, ctx); + + TestCASESecurePairingDelegate delegateCommissioner1, delegateCommissioner2; + CASESession pairingCommissioner1, pairingCommissioner2; + + pairingCommissioner1.SetGroupDataProvider(&gCommissionerGroupDataProvider); + pairingCommissioner2.SetGroupDataProvider(&gCommissionerGroupDataProvider); + + auto & loopback = ctx.GetLoopback(); + loopback.mSentMessageCount = 0; + + NL_TEST_ASSERT(inSuite, + gPairingServer.ListenForSessionEstablishment(&ctx.GetExchangeManager(), &ctx.GetSecureSessionManager(), + &gDeviceFabrics, nullptr, nullptr, + &gDeviceGroupDataProvider) == CHIP_NO_ERROR); + + ExchangeContext * contextCommissioner1 = ctx.NewUnauthenticatedExchangeToBob(&pairingCommissioner1); + ExchangeContext * contextCommissioner2 = ctx.NewUnauthenticatedExchangeToBob(&pairingCommissioner2); + + NL_TEST_ASSERT(inSuite, + pairingCommissioner1.EstablishSession(sessionManager, &gCommissionerFabrics, + ScopedNodeId{ Node01_01, gCommissionerFabricIndex }, contextCommissioner1, + nullptr, nullptr, &delegateCommissioner1, NullOptional) == CHIP_NO_ERROR); + NL_TEST_ASSERT(inSuite, + pairingCommissioner2.EstablishSession(sessionManager, &gCommissionerFabrics, + ScopedNodeId{ Node01_01, gCommissionerFabricIndex }, contextCommissioner2, + nullptr, nullptr, &delegateCommissioner2, NullOptional) == CHIP_NO_ERROR); + + ServiceEvents(ctx); + + // We should have one full handshake and one Sigma1 + Busy + ack. If that + // ever changes (e.g. because our server starts supporting multiple parallel + // handshakes), this test needs to be fixed so that the server is still + // responding BUSY to the client. + NL_TEST_ASSERT(inSuite, loopback.mSentMessageCount == sTestCaseMessageCount + 3); + NL_TEST_ASSERT(inSuite, delegateCommissioner1.mNumPairingComplete == 1); + NL_TEST_ASSERT(inSuite, delegateCommissioner2.mNumPairingComplete == 0); + + NL_TEST_ASSERT(inSuite, delegateCommissioner1.mNumPairingErrors == 0); + NL_TEST_ASSERT(inSuite, delegateCommissioner2.mNumPairingErrors == 1); + + NL_TEST_ASSERT(inSuite, delegateCommissioner1.mNumBusyResponses == 0); + NL_TEST_ASSERT(inSuite, delegateCommissioner2.mNumBusyResponses == 1); + + gPairingServer.Shutdown(); } struct Sigma1Params @@ -1115,6 +1176,7 @@ static const nlTest sTests[] = NL_TEST_DEF("Start", chip::TestCASESession::SecurePairingStartTest), NL_TEST_DEF("Handshake", chip::TestCASESession::SecurePairingHandshakeTest), NL_TEST_DEF("ServerHandshake", chip::TestCASESession::SecurePairingHandshakeServerTest), + NL_TEST_DEF("ClientReceivesBusy", chip::TestCASESession::ClientReceivesBusyTest), NL_TEST_DEF("Sigma1Parsing", chip::TestCASESession::Sigma1ParsingTest), NL_TEST_DEF("DestinationId", chip::TestCASESession::DestinationIdTest), NL_TEST_DEF("SessionResumptionStorage", chip::TestCASESession::SessionResumptionStorage),