diff --git a/src/credentials/CertificationDeclaration.cpp b/src/credentials/CertificationDeclaration.cpp index 08149bf6c44d26..42bd9656b30750 100644 --- a/src/credentials/CertificationDeclaration.cpp +++ b/src/credentials/CertificationDeclaration.cpp @@ -336,15 +336,10 @@ CHIP_ERROR CertificationElementsDecoder::FindAndEnterArray(const ByteSpan & enco ReturnErrorOnFailure(mReader.EnterContainer(outerContainerType1)); // position to arrayTag Array - CHIP_ERROR error = CHIP_NO_ERROR; do { - error = mReader.Next(kTLVType_Array, arrayTag); - // Return error code unless one of three things happened: - // 1. We found the right thing (CHIP_NO_ERROR returned). - // 2. The next tag is not the one we are looking for (CHIP_ERROR_UNEXPECTED_TLV_ELEMENT). - VerifyOrReturnError(error == CHIP_NO_ERROR || error == CHIP_ERROR_UNEXPECTED_TLV_ELEMENT, error); - } while (error != CHIP_NO_ERROR); + ReturnErrorOnFailure(mReader.Next()); + } while (mReader.Expect(kTLVType_Array, arrayTag) != CHIP_NO_ERROR); ReturnErrorOnFailure(mReader.EnterContainer(outerContainerType2)); diff --git a/src/lib/core/TLVReader.cpp b/src/lib/core/TLVReader.cpp index 1fe3b6b4c695f0..7942b7f803bec9 100644 --- a/src/lib/core/TLVReader.cpp +++ b/src/lib/core/TLVReader.cpp @@ -585,7 +585,8 @@ CHIP_ERROR TLVReader::Next() CHIP_ERROR TLVReader::Expect(Tag expectedTag) { - VerifyOrReturnError(mElemTag == expectedTag, CHIP_ERROR_UNEXPECTED_TLV_ELEMENT); + VerifyOrReturnError(GetType() != kTLVType_NotSpecified, CHIP_ERROR_WRONG_TLV_TYPE); + VerifyOrReturnError(GetTag() == expectedTag, CHIP_ERROR_UNEXPECTED_TLV_ELEMENT); return CHIP_NO_ERROR; } @@ -598,8 +599,8 @@ CHIP_ERROR TLVReader::Next(Tag expectedTag) CHIP_ERROR TLVReader::Expect(TLVType expectedType, Tag expectedTag) { - ReturnErrorOnFailure(Expect(expectedTag)); VerifyOrReturnError(GetType() == expectedType, CHIP_ERROR_WRONG_TLV_TYPE); + VerifyOrReturnError(GetTag() == expectedTag, CHIP_ERROR_UNEXPECTED_TLV_ELEMENT); return CHIP_NO_ERROR; } diff --git a/src/lib/core/TLVReader.h b/src/lib/core/TLVReader.h index 6be621e6382ce6..744402a8b059cd 100644 --- a/src/lib/core/TLVReader.h +++ b/src/lib/core/TLVReader.h @@ -167,7 +167,10 @@ class DLL_EXPORT TLVReader * Advances the TLVReader object to the next TLV element to be read, asserting the tag of * the new element. * - * This is a convenience method that combines the behavior of Next() and Expect(). + * This is a convenience method that combines the behavior of Next() and Expect(...). + * + * Note that if this method returns an error, the reader may or may not have been advanced already. + * In use cases where this is important, separate calls to Next() and Expect(...) should be made. * * @retval #CHIP_NO_ERROR If the reader was successfully positioned on a new element * matching the expected parameters. @@ -179,6 +182,7 @@ class DLL_EXPORT TLVReader * Checks that the TLV reader is positioned at an element with the expected tag. * * @retval #CHIP_NO_ERROR If the reader is positioned on the expected element. + * @retval #CHIP_ERROR_WRONG_TLV_TYPE If the reader is not positioned on an element. * @retval #CHIP_ERROR_UNEXPECTED_TLV_ELEMENT * If the tag associated with the new element does not match the * value of the @p expectedTag argument. @@ -189,7 +193,10 @@ class DLL_EXPORT TLVReader * Advances the TLVReader object to the next TLV element to be read, asserting the type and tag of * the new element. * - * This is a convenience method that combines the behavior of Next() and Expect(). + * This is a convenience method that combines the behavior of Next() and Expect(...). + * + * Note that if this method returns an error, the reader may or may not have been advanced already. + * In use cases where this is important, separate calls to Next() and Expect(...) should be made. * * @retval #CHIP_NO_ERROR If the reader was successfully positioned on a new element * matching the expected parameters. diff --git a/src/lib/core/tests/TestTLV.cpp b/src/lib/core/tests/TestTLV.cpp index 0ea46d463bcc3e..529da85a247fc5 100644 --- a/src/lib/core/tests/TestTLV.cpp +++ b/src/lib/core/tests/TestTLV.cpp @@ -3628,6 +3628,59 @@ void TestTLVReaderErrorHandling(nlTestSuite * inSuite) chip::Platform::MemoryFree(const_cast(data)); } +void TestTLVReaderExpect(nlTestSuite * inSuite) +{ + // Prepare some test data + uint8_t buffer[20]; + TLVWriter writer; + writer.Init(buffer, sizeof(buffer)); + TLVType outerContainer; + NL_TEST_ASSERT_SUCCESS(inSuite, writer.StartContainer(AnonymousTag(), kTLVType_Structure, outerContainer)); + NL_TEST_ASSERT_SUCCESS(inSuite, writer.PutBoolean(ContextTag(23), false)); + NL_TEST_ASSERT_SUCCESS(inSuite, writer.EndContainer(outerContainer)); + + TLVReader reader; + reader.Init(buffer, writer.GetLengthWritten()); + + // Positioned before the first element + NL_TEST_ASSERT(inSuite, reader.GetType() == kTLVType_NotSpecified); + + NL_TEST_ASSERT(inSuite, reader.Expect(AnonymousTag()) == CHIP_ERROR_WRONG_TLV_TYPE); + NL_TEST_ASSERT(inSuite, reader.Expect(ContextTag(23)) == CHIP_ERROR_WRONG_TLV_TYPE); + NL_TEST_ASSERT(inSuite, reader.Expect(kTLVType_Boolean, AnonymousTag()) == CHIP_ERROR_WRONG_TLV_TYPE); + + // Positioned on kTLVType_Structure / AnonymousTag(), + NL_TEST_ASSERT_SUCCESS(inSuite, reader.Next()); + NL_TEST_ASSERT(inSuite, reader.GetType() == kTLVType_Structure); + NL_TEST_ASSERT(inSuite, reader.GetTag() == AnonymousTag()); + + NL_TEST_ASSERT(inSuite, reader.Expect(ContextTag(23)) == CHIP_ERROR_UNEXPECTED_TLV_ELEMENT); + NL_TEST_ASSERT_SUCCESS(inSuite, reader.Expect(AnonymousTag())); + + NL_TEST_ASSERT(inSuite, reader.Expect(kTLVType_SignedInteger, AnonymousTag()) == CHIP_ERROR_WRONG_TLV_TYPE); + NL_TEST_ASSERT_SUCCESS(inSuite, reader.Expect(kTLVType_Structure, AnonymousTag())); + + // Positioned before first struct element + NL_TEST_ASSERT_SUCCESS(inSuite, reader.EnterContainer(outerContainer)); + NL_TEST_ASSERT(inSuite, reader.GetType() == kTLVType_NotSpecified); + + NL_TEST_ASSERT(inSuite, reader.Expect(AnonymousTag()) == CHIP_ERROR_WRONG_TLV_TYPE); + NL_TEST_ASSERT(inSuite, reader.Expect(ContextTag(23)) == CHIP_ERROR_WRONG_TLV_TYPE); + NL_TEST_ASSERT(inSuite, reader.Expect(kTLVType_Boolean, AnonymousTag()) == CHIP_ERROR_WRONG_TLV_TYPE); + + // Positioned on kTLVType_Boolean / ContextTag(23) + NL_TEST_ASSERT_SUCCESS(inSuite, reader.Next()); + NL_TEST_ASSERT(inSuite, reader.GetType() == kTLVType_Boolean); + NL_TEST_ASSERT(inSuite, reader.GetTag() == ContextTag(23)); + + NL_TEST_ASSERT(inSuite, reader.Expect(AnonymousTag()) == CHIP_ERROR_UNEXPECTED_TLV_ELEMENT); + NL_TEST_ASSERT(inSuite, reader.Expect(ContextTag(42)) == CHIP_ERROR_UNEXPECTED_TLV_ELEMENT); + NL_TEST_ASSERT_SUCCESS(inSuite, reader.Expect(ContextTag(23))); + + NL_TEST_ASSERT(inSuite, reader.Expect(kTLVType_SignedInteger, ContextTag(23)) == CHIP_ERROR_WRONG_TLV_TYPE); + NL_TEST_ASSERT_SUCCESS(inSuite, reader.Expect(kTLVType_Boolean, ContextTag(23))); +} + /** * Test that CHIP TLV reader returns an error when a read is requested that * would truncate the output. @@ -3789,6 +3842,8 @@ void CheckTLVReader(nlTestSuite * inSuite, void * inContext) TestTLVReaderErrorHandling(inSuite); + TestTLVReaderExpect(inSuite); + TestTLVReaderTruncatedReads(inSuite); TestTLVReaderInPractice(inSuite);