From 6446e00b398e08d185ceb5ff51cac7e717b852c9 Mon Sep 17 00:00:00 2001
From: Michael Sandstedt <michael.sandstedt@smartthings.com>
Date: Mon, 4 Jul 2022 20:31:47 -0500
Subject: [PATCH] Fix CATValues == operator (#20253)

* Fix CATValues == operator

CATValues == was using std::array equality, but this considers order
of the data, which is not relevant for CATs.  Fix this by sorting first
and then comparing.

While we're at it, add != and < operators.  The latter is useful for
sorting (e.g. for maps and sets).

Fixes #20252

Testing: added unit tests for the new and corrected operators.

* Update src/lib/core/tests/TestCATValues.cpp

Co-authored-by: Tennessee Carmel-Veilleux <tennessee.carmelveilleux@gmail.com>

* Update src/lib/core/tests/TestCATValues.cpp

Co-authored-by: Tennessee Carmel-Veilleux <tennessee.carmelveilleux@gmail.com>

* Update src/lib/core/tests/TestCATValues.cpp

Co-authored-by: Tennessee Carmel-Veilleux <tennessee.carmelveilleux@gmail.com>

* Fix comparison algorithm to properly handled kUndefinedCAT and repeated values

* CATValues are more like a set: repeated values aren't relevant
* kUndefinedCAT especially must be ignored in comparison

Refactor the == operator to more precisely reflect the sementics of CATs
given the above.  This removes the need to instantiate temporary stack
objcts for sort; simple sorting won't work anyway.

* Fixes to CASEAuthTag for comparison

- Adds basic set operation of `Contains`, `ContainsIdentifier`
  and `GetNumTagsPresent`
- Add missing unit tests for existing operations

* restyle

* Update src/lib/core/CASEAuthTag.h

Co-authored-by: Tennessee Carmel-Veilleux <tennessee.carmelveilleux@gmail.com>
---
 src/lib/core/CASEAuthTag.h           |  94 ++++++++++++-
 src/lib/core/tests/BUILD.gn          |   1 +
 src/lib/core/tests/TestCATValues.cpp | 199 +++++++++++++++++++++++++++
 3 files changed, 289 insertions(+), 5 deletions(-)
 create mode 100644 src/lib/core/tests/TestCATValues.cpp

diff --git a/src/lib/core/CASEAuthTag.h b/src/lib/core/CASEAuthTag.h
index 4e5a7b05150bac..b67e86709cb0b4 100644
--- a/src/lib/core/CASEAuthTag.h
+++ b/src/lib/core/CASEAuthTag.h
@@ -28,9 +28,10 @@ namespace chip {
 
 typedef uint32_t CASEAuthTag;
 
-static constexpr CASEAuthTag kUndefinedCAT = 0;
-static constexpr NodeId kTagIdentifierMask = 0x0000'0000'FFFF'0000ULL;
-static constexpr NodeId kTagVersionMask    = 0x0000'0000'0000'FFFFULL;
+static constexpr CASEAuthTag kUndefinedCAT    = 0;
+static constexpr NodeId kTagIdentifierMask    = 0x0000'0000'FFFF'0000ULL;
+static constexpr uint32_t kTagIdentifierShift = 16;
+static constexpr NodeId kTagVersionMask       = 0x0000'0000'0000'FFFFULL;
 
 // Maximum number of CASE Authenticated Tags (CAT) in the CHIP certificate subject.
 static constexpr size_t kMaxSubjectCATAttributeCount = CHIP_CONFIG_CERT_MAX_RDN_ATTRIBUTES - 2;
@@ -39,14 +40,65 @@ struct CATValues
 {
     std::array<CASEAuthTag, kMaxSubjectCATAttributeCount> values = { kUndefinedCAT };
 
-    /* @brief Returns size of the CAT values array.
+    /* @brief Returns maximum number of CAT values that the array can contain.
      */
     static constexpr size_t size() { return std::tuple_size<decltype(values)>::value; }
 
+    /**
+     * @return the number of CATs present in the set (values not equal to kUndefinedCAT)
+     */
+    size_t GetNumTagsPresent() const
+    {
+        size_t count = 0;
+        for (auto cat : values)
+        {
+            count += (cat != kUndefinedCAT) ? 1 : 0;
+        }
+        return count;
+    }
+
+    /**
+     * @return true if `tag` is in the set exactly, false otherwise.
+     */
+    bool Contains(CASEAuthTag tag) const
+    {
+        for (auto candidate : values)
+        {
+            if ((candidate != kUndefinedCAT) && (candidate == tag))
+            {
+                return true;
+            }
+        }
+
+        return false;
+    }
+
+    /**
+     * @brief Returns true if this set contains any version of the `identifier`
+     *
+     * @param identifier - CAT identifier to find
+     * @return true if the identifier is in the set, false otherwise
+     */
+    bool ContainsIdentifier(uint16_t identifier) const
+    {
+        for (auto candidate : values)
+        {
+            uint16_t candidate_identifier = static_cast<uint16_t>((candidate & kTagIdentifierMask) >> kTagIdentifierShift);
+            if ((candidate != kUndefinedCAT) && (identifier == candidate_identifier))
+            {
+                return true;
+            }
+        }
+
+        return false;
+    }
+
     /* @brief Returns true if subject input checks against one of the CATs in the values array.
      */
     bool CheckSubjectAgainstCATs(NodeId subject) const
     {
+        VerifyOrReturnError(IsCASEAuthTag(subject), false);
+
         for (auto cat : values)
         {
             // All valid CAT values are always in the beginning of the array followed by kUndefinedCAT values.
@@ -60,7 +112,29 @@ struct CATValues
         return false;
     }
 
-    bool operator==(const CATValues & that) const { return values == that.values; }
+    bool operator==(const CATValues & other) const
+    {
+        // Two sets of CATs confer equal permissions if the sets are exactly equal.
+        // Ignoring kUndefinedCAT values, evaluate this.
+        if (this->GetNumTagsPresent() != other.GetNumTagsPresent())
+        {
+            return false;
+        }
+        for (auto cat : this->values)
+        {
+            if (cat == kUndefinedCAT)
+            {
+                continue;
+            }
+
+            if (!other.Contains(cat))
+            {
+                return false;
+            }
+        }
+        return true;
+    }
+    bool operator!=(const CATValues & other) const { return !(*this == other); }
 
     static constexpr size_t kSerializedLength = kMaxSubjectCATAttributeCount * sizeof(CASEAuthTag);
     typedef uint8_t Serialized[kSerializedLength];
@@ -103,4 +177,14 @@ constexpr bool IsValidCASEAuthTag(CASEAuthTag aCAT)
     return (aCAT & kTagVersionMask) > 0;
 }
 
+constexpr uint16_t GetCASEAuthTagIdentifier(CASEAuthTag aCAT)
+{
+    return static_cast<uint16_t>((aCAT & kTagIdentifierMask) >> kTagIdentifierShift);
+}
+
+constexpr uint16_t GetCASEAuthTagVersion(CASEAuthTag aCAT)
+{
+    return static_cast<uint16_t>(aCAT & kTagVersionMask);
+}
+
 } // namespace chip
diff --git a/src/lib/core/tests/BUILD.gn b/src/lib/core/tests/BUILD.gn
index f6fc3b46804262..86120cd3a5ae53 100644
--- a/src/lib/core/tests/BUILD.gn
+++ b/src/lib/core/tests/BUILD.gn
@@ -22,6 +22,7 @@ chip_test_suite("tests") {
   output_name = "libCoreTests"
 
   test_sources = [
+    "TestCATValues.cpp",
     "TestCHIPCallback.cpp",
     "TestCHIPErrorStr.cpp",
     "TestCHIPTLV.cpp",
diff --git a/src/lib/core/tests/TestCATValues.cpp b/src/lib/core/tests/TestCATValues.cpp
new file mode 100644
index 00000000000000..95ab813f89e638
--- /dev/null
+++ b/src/lib/core/tests/TestCATValues.cpp
@@ -0,0 +1,199 @@
+/*
+ *
+ *    Copyright (c) 2022 Project CHIP Authors
+ *    All rights reserved.
+ *
+ *    Licensed under the Apache License, Version 2.0 (the "License");
+ *    you may not use this file except in compliance with the License.
+ *    You may obtain a copy of the License at
+ *
+ *        http://www.apache.org/licenses/LICENSE-2.0
+ *
+ *    Unless required by applicable law or agreed to in writing, software
+ *    distributed under the License is distributed on an "AS IS" BASIS,
+ *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ *    See the License for the specific language governing permissions and
+ *    limitations under the License.
+ */
+
+#include <lib/support/UnitTestRegistration.h>
+#include <nlunit-test.h>
+
+#include <lib/core/CASEAuthTag.h>
+
+using namespace chip;
+
+void TestEqualityOperator(nlTestSuite * inSuite, void * inContext)
+{
+    {
+        auto a                 = CATValues{ { 0x1111'0001, 0x2222'0002, 0x3333'0003 } };
+        auto b                 = CATValues{ { 0x1111'0001, 0x3333'0003, 0x2222'0002 } };
+        auto c                 = CATValues{ { 0x2222'0002, 0x1111'0001, 0x3333'0003 } };
+        auto d                 = CATValues{ { 0x2222'0002, 0x3333'0003, 0x1111'0001 } };
+        auto e                 = CATValues{ { 0x3333'0003, 0x1111'0001, 0x2222'0002 } };
+        auto f                 = CATValues{ { 0x3333'0003, 0x2222'0002, 0x1111'0001 } };
+        CATValues candidates[] = { a, b, c, d, e, f };
+        for (auto & outer : candidates)
+        {
+            for (auto & inner : candidates)
+            {
+                NL_TEST_ASSERT(inSuite, inner == outer);
+            }
+        }
+    }
+    {
+        auto a                 = CATValues{ {} };
+        auto b                 = CATValues{ {} };
+        CATValues candidates[] = { a, b };
+        for (auto & outer : candidates)
+        {
+            for (auto & inner : candidates)
+            {
+                NL_TEST_ASSERT(inSuite, inner == outer);
+            }
+        }
+    }
+}
+
+void TestInequalityOperator(nlTestSuite * inSuite, void * inContext)
+{
+    auto a                 = CATValues{ { 0x1111'0001 } };
+    auto b                 = CATValues{ { 0x1111'0001, 0x2222'0002 } };
+    auto c                 = CATValues{ { 0x1111'0001, 0x2222'0002, 0x3333'0003 } };
+    auto d                 = CATValues{ { 0x2222'0002 } };
+    auto e                 = CATValues{ { 0x2222'0002, 0x3333'0003 } };
+    auto f                 = CATValues{ { 0x2222'0002, 0x3333'0003, 0x4444'0004 } };
+    auto g                 = CATValues{ { 0x3333'0003 } };
+    auto h                 = CATValues{ { 0x3333'0003, 0x4444'0004 } };
+    auto i                 = CATValues{ { 0x3333'0003, 0x4444'0004, 0x5555'0005 } };
+    auto j                 = CATValues{ { 0x4444'0004 } };
+    auto k                 = CATValues{ { 0x4444'0004, 0x5555'0005 } };
+    auto l                 = CATValues{ { 0x4444'0004, 0x5555'0005, 0x6666'0006 } };
+    auto m                 = CATValues{ { 0x5555'0005 } };
+    auto n                 = CATValues{ { 0x5555'0005, 0x6666'0006 } };
+    auto o                 = CATValues{ { 0x5555'0005, 0x6666'0006, 0x7777'0007 } };
+    auto p                 = CATValues{ { 0x6666'0006 } };
+    auto q                 = CATValues{ { 0x6666'0006, 0x7777'0007 } };
+    auto r                 = CATValues{ { 0x6666'0006, 0x7777'0007, 0x8888'0008 } };
+    auto s                 = CATValues{ { 0x7777'0007 } };
+    auto t                 = CATValues{ { 0x7777'0007, 0x8888'0008 } };
+    auto u                 = CATValues{ { 0x7777'0007, 0x8888'0008, 0x9999'0009 } };
+    CATValues candidates[] = { a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u };
+    for (auto & outer : candidates)
+    {
+        for (auto & inner : candidates)
+        {
+            if (&inner == &outer)
+            {
+                continue;
+            }
+            NL_TEST_ASSERT(inSuite, inner != outer);
+        }
+    }
+}
+
+void TestMembership(nlTestSuite * inSuite, void * inContext)
+{
+    auto a = CATValues{ { 0x1111'0001 } };
+    auto b = CATValues{ { 0x1111'0001, 0x2222'0002 } };
+    auto c = CATValues{ { 0x1111'0001, 0x2222'0002, 0x3333'0003 } };
+
+    NL_TEST_ASSERT(inSuite, a.Contains(0x1111'0001));
+    NL_TEST_ASSERT(inSuite, a.GetNumTagsPresent() == 1);
+    NL_TEST_ASSERT(inSuite, !a.Contains(0x1111'0002));
+    NL_TEST_ASSERT(inSuite, !a.Contains(0x2222'0002));
+    NL_TEST_ASSERT(inSuite, a.ContainsIdentifier(0x1111));
+    NL_TEST_ASSERT(inSuite, !a.ContainsIdentifier(0x2222));
+
+    NL_TEST_ASSERT(inSuite, b.Contains(0x1111'0001));
+    NL_TEST_ASSERT(inSuite, b.Contains(0x2222'0002));
+    NL_TEST_ASSERT(inSuite, b.GetNumTagsPresent() == 2);
+    NL_TEST_ASSERT(inSuite, b.ContainsIdentifier(0x1111));
+    NL_TEST_ASSERT(inSuite, b.ContainsIdentifier(0x2222));
+
+    NL_TEST_ASSERT(inSuite, c.Contains(0x1111'0001));
+    NL_TEST_ASSERT(inSuite, c.Contains(0x2222'0002));
+    NL_TEST_ASSERT(inSuite, c.Contains(0x3333'0003));
+    NL_TEST_ASSERT(inSuite, c.GetNumTagsPresent() == 3);
+    NL_TEST_ASSERT(inSuite, c.ContainsIdentifier(0x1111));
+    NL_TEST_ASSERT(inSuite, c.ContainsIdentifier(0x2222));
+    NL_TEST_ASSERT(inSuite, c.ContainsIdentifier(0x3333));
+}
+
+void TestSubjectMatching(nlTestSuite * inSuite, void * inContext)
+{
+    // Check operational node IDs don't match
+    auto a = CATValues{ { 0x2222'0002 } };
+    NL_TEST_ASSERT(inSuite, !a.CheckSubjectAgainstCATs(static_cast<chip::NodeId>(0x0001'0002'0003'0004ull)));
+    NL_TEST_ASSERT(inSuite, !a.CheckSubjectAgainstCATs(static_cast<chip::NodeId>(0x0001'0002'2222'0002ull)));
+
+    auto b = CATValues{ { 0x1111'0001 } };
+    NL_TEST_ASSERT(inSuite, b.CheckSubjectAgainstCATs(static_cast<chip::NodeId>(0xFFFF'FFFD'1111'0001ull)));
+    NL_TEST_ASSERT(inSuite, !b.CheckSubjectAgainstCATs(static_cast<chip::NodeId>(0xFFFF'FFFD'1111'0002ull)));
+
+    auto c = CATValues{ { 0x1111'0001, 0x2222'0002 } };
+    NL_TEST_ASSERT(inSuite, c.CheckSubjectAgainstCATs(static_cast<chip::NodeId>(0xFFFF'FFFD'2222'0001ull)));
+    NL_TEST_ASSERT(inSuite, c.CheckSubjectAgainstCATs(static_cast<chip::NodeId>(0xFFFF'FFFD'2222'0002ull)));
+    NL_TEST_ASSERT(inSuite, !c.CheckSubjectAgainstCATs(static_cast<chip::NodeId>(0xFFFF'FFFD'2222'0003ull)));
+
+    auto d = CATValues{ { 0x1111'0001, 0x2222'0002, 0x3333'0003 } };
+    NL_TEST_ASSERT(inSuite, d.CheckSubjectAgainstCATs(static_cast<chip::NodeId>(0xFFFF'FFFD'3333'0001ull)));
+    NL_TEST_ASSERT(inSuite, d.CheckSubjectAgainstCATs(static_cast<chip::NodeId>(0xFFFF'FFFD'3333'0002ull)));
+    NL_TEST_ASSERT(inSuite, d.CheckSubjectAgainstCATs(static_cast<chip::NodeId>(0xFFFF'FFFD'3333'0003ull)));
+    NL_TEST_ASSERT(inSuite, !d.CheckSubjectAgainstCATs(static_cast<chip::NodeId>(0xFFFF'FFFD'3333'0004ull)));
+    NL_TEST_ASSERT(inSuite, !d.CheckSubjectAgainstCATs(static_cast<chip::NodeId>(0xFFFF'FFFD'3333'ffffull)));
+
+    auto e = CATValues{ { 0x1111'0001, 0x2222'0002, 0x3333'ffff } };
+    NL_TEST_ASSERT(inSuite, e.CheckSubjectAgainstCATs(static_cast<chip::NodeId>(0xFFFF'FFFD'3333'0001ull)));
+    NL_TEST_ASSERT(inSuite, e.CheckSubjectAgainstCATs(static_cast<chip::NodeId>(0xFFFF'FFFD'3333'0002ull)));
+    NL_TEST_ASSERT(inSuite, e.CheckSubjectAgainstCATs(static_cast<chip::NodeId>(0xFFFF'FFFD'3333'0003ull)));
+    NL_TEST_ASSERT(inSuite, e.CheckSubjectAgainstCATs(static_cast<chip::NodeId>(0xFFFF'FFFD'3333'0004ull)));
+    NL_TEST_ASSERT(inSuite, e.CheckSubjectAgainstCATs(static_cast<chip::NodeId>(0xFFFF'FFFD'3333'ffffull)));
+}
+// Test Suite
+
+/**
+ *  Test Suite that lists all the test functions.
+ */
+// clang-format off
+static const nlTest sTests[] =
+{
+    NL_TEST_DEF("Equality operator", TestEqualityOperator),
+    NL_TEST_DEF("Inequality operator", TestInequalityOperator),
+    NL_TEST_DEF("Set operations", TestMembership),
+    NL_TEST_DEF("Subject matching for ACL", TestSubjectMatching),
+    NL_TEST_SENTINEL()
+};
+// clang-format on
+
+int TestCATValues_Setup(void * inContext)
+{
+    return SUCCESS;
+}
+
+/**
+ *  Tear down the test suite.
+ */
+int TestCATValues_Teardown(void * inContext)
+{
+    return SUCCESS;
+}
+
+int TestCATValues(void)
+{
+    // clang-format off
+    nlTestSuite theSuite =
+    {
+        "CATValues",
+        &sTests[0],
+        TestCATValues_Setup,
+        TestCATValues_Teardown,
+    };
+    // clang-format on
+
+    nlTestRunner(&theSuite, nullptr);
+
+    return (nlTestRunnerStats(&theSuite));
+}
+
+CHIP_REGISTER_TEST_SUITE(TestCATValues)