From 25bc83d01ad9176e8e0c50d61cf8920c10c52f6e Mon Sep 17 00:00:00 2001 From: Justin King Date: Fri, 16 Aug 2024 13:39:22 -0700 Subject: [PATCH] Add `EnumType` which covers the enum proposal, which is type only PiperOrigin-RevId: 663846915 --- common/kind.cc | 2 + common/kind.h | 1 + common/type.cc | 33 ++++++-- common/type.h | 24 ++++++ common/type_kind.h | 1 + common/type_test.cc | 65 +++++++++++++++- common/types/enum_type.cc | 43 +++++++++++ common/types/enum_type.h | 137 +++++++++++++++++++++++++++++++++ common/types/enum_type_test.cc | 66 ++++++++++++++++ common/types/types.h | 31 ++++---- 10 files changed, 379 insertions(+), 24 deletions(-) create mode 100644 common/types/enum_type.cc create mode 100644 common/types/enum_type.h create mode 100644 common/types/enum_type_test.cc diff --git a/common/kind.cc b/common/kind.cc index 364d9ca9c..21fb9e9f3 100644 --- a/common/kind.cc +++ b/common/kind.cc @@ -70,6 +70,8 @@ absl::string_view KindToString(Kind kind) { return "google.protobuf.StringValue"; case Kind::kBytesWrapper: return "google.protobuf.BytesValue"; + case Kind::kEnum: + return "enum"; default: return "*error*"; } diff --git a/common/kind.h b/common/kind.h index 55a9b805c..60a1e10b9 100644 --- a/common/kind.h +++ b/common/kind.h @@ -52,6 +52,7 @@ enum class Kind /* : uint8_t */ { kTypeParam, kFunction, + kEnum, // Legacy aliases, deprecated do not use. kNullType = kNull, diff --git a/common/type.cc b/common/type.cc index adb7d4e84..1d92a0ec7 100644 --- a/common/type.cc +++ b/common/type.cc @@ -69,19 +69,27 @@ Type Type::Message(absl::Nonnull descriptor) { return MessageType(descriptor); } } + +Type Type::Enum(absl::Nonnull descriptor) { + if (descriptor->full_name() == "google.protobuf.NullValue") { + return NullType(); + } + return EnumType(descriptor); +} + namespace { -static constexpr std::array kTypeToKindArray = { +static constexpr std::array kTypeToKindArray = { TypeKind::kError, TypeKind::kAny, TypeKind::kBool, TypeKind::kBoolWrapper, TypeKind::kBytes, TypeKind::kBytesWrapper, TypeKind::kDouble, TypeKind::kDoubleWrapper, TypeKind::kDuration, - TypeKind::kDyn, TypeKind::kError, TypeKind::kFunction, - TypeKind::kInt, TypeKind::kIntWrapper, TypeKind::kList, - TypeKind::kMap, TypeKind::kNull, TypeKind::kOpaque, - TypeKind::kString, TypeKind::kStringWrapper, TypeKind::kStruct, - TypeKind::kStruct, TypeKind::kTimestamp, TypeKind::kTypeParam, - TypeKind::kType, TypeKind::kUint, TypeKind::kUintWrapper, - TypeKind::kUnknown}; + TypeKind::kDyn, TypeKind::kEnum, TypeKind::kError, + TypeKind::kFunction, TypeKind::kInt, TypeKind::kIntWrapper, + TypeKind::kList, TypeKind::kMap, TypeKind::kNull, + TypeKind::kOpaque, TypeKind::kString, TypeKind::kStringWrapper, + TypeKind::kStruct, TypeKind::kStruct, TypeKind::kTimestamp, + TypeKind::kTypeParam, TypeKind::kType, TypeKind::kUint, + TypeKind::kUintWrapper, TypeKind::kUnknown}; static_assert(kTypeToKindArray.size() == absl::variant_size(), @@ -222,6 +230,10 @@ absl::optional Type::AsDyn() const { return GetOrNullopt(variant_); } +absl::optional Type::AsEnum() const { + return GetOrNullopt(variant_); +} + absl::optional Type::AsError() const { return GetOrNullopt(variant_); } @@ -363,6 +375,11 @@ Type::operator DynType() const { return GetOrDie(variant_); } +Type::operator EnumType() const { + ABSL_DCHECK(IsEnum()) << DebugString(); + return GetOrDie(variant_); +} + Type::operator ErrorType() const { ABSL_DCHECK(IsError()) << DebugString(); return GetOrDie(variant_); diff --git a/common/type.h b/common/type.h index 56a3aa1bb..91b4cfe96 100644 --- a/common/type.h +++ b/common/type.h @@ -43,6 +43,7 @@ #include "common/types/double_wrapper_type.h" // IWYU pragma: export #include "common/types/duration_type.h" // IWYU pragma: export #include "common/types/dyn_type.h" // IWYU pragma: export +#include "common/types/enum_type.h" // IWYU pragma: export #include "common/types/error_type.h" // IWYU pragma: export #include "common/types/function_type.h" // IWYU pragma: export #include "common/types/int_type.h" // IWYU pragma: export @@ -86,6 +87,12 @@ class Type final { static Type Message(absl::Nonnull descriptor ABSL_ATTRIBUTE_LIFETIME_BOUND); + // Returns an appropriate `Type` for the dynamic protobuf enum. For well + // known enum types, the appropriate `Type` is returned. All others return + // `EnumType`. + static Type Enum(absl::Nonnull descriptor + ABSL_ATTRIBUTE_LIFETIME_BOUND); + Type() = default; Type(const Type&) = default; Type(Type&&) = default; @@ -205,6 +212,8 @@ class Type final { bool IsDyn() const { return absl::holds_alternative(variant_); } + bool IsEnum() const { return absl::holds_alternative(variant_); } + bool IsError() const { return absl::holds_alternative(variant_); } bool IsFunction() const { @@ -317,6 +326,11 @@ class Type final { return IsDyn(); } + template + std::enable_if_t, bool> Is() const { + return IsEnum(); + } + template std::enable_if_t, bool> Is() const { return IsError(); @@ -430,6 +444,8 @@ class Type final { absl::optional AsDyn() const; + absl::optional AsEnum() const; + absl::optional AsError() const; absl::optional AsFunction() const; @@ -534,6 +550,12 @@ class Type final { return AsDyn(); } + template + std::enable_if_t, absl::optional> As() + const { + return AsEnum(); + } + template std::enable_if_t, absl::optional> As() const { @@ -677,6 +699,8 @@ class Type final { explicit operator DynType() const; + explicit operator EnumType() const; + explicit operator ErrorType() const; explicit operator FunctionType() const; diff --git a/common/type_kind.h b/common/type_kind.h index 58e701472..1e9e94df0 100644 --- a/common/type_kind.h +++ b/common/type_kind.h @@ -56,6 +56,7 @@ enum class TypeKind : std::underlying_type_t { kTypeParam = static_cast(Kind::kTypeParam), kFunction = static_cast(Kind::kFunction), + kEnum = static_cast(Kind::kEnum), // Legacy aliases, deprecated do not use. kNullType = kNull, diff --git a/common/type_test.cc b/common/type_test.cc index 35ed9d860..916c23906 100644 --- a/common/type_test.cc +++ b/common/type_test.cc @@ -16,17 +16,32 @@ #include "absl/hash/hash.h" #include "absl/hash/hash_testing.h" -#include "common/native_type.h" +#include "absl/log/die_if_null.h" #include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" #include "google/protobuf/arena.h" namespace cel { namespace { +using ::cel::internal::GetTestingDescriptorPool; using testing::_; using testing::An; using testing::Optional; +TEST(Type, Enum) { + EXPECT_EQ( + Type::Enum( + ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindEnumTypeByName( + "google.api.expr.test.v1.proto3.TestAllTypes.NestedEnum"))), + EnumType(ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindEnumTypeByName( + "google.api.expr.test.v1.proto3.TestAllTypes.NestedEnum")))); + EXPECT_EQ(Type::Enum( + ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindEnumTypeByName( + "google.protobuf.NullValue"))), + NullType()); +} + TEST(Type, KindDebugDeath) { Type type; static_cast(type); @@ -75,6 +90,12 @@ TEST(Type, Is) { EXPECT_TRUE(Type(DynType()).Is()); + EXPECT_TRUE( + Type(EnumType( + ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindEnumTypeByName( + "google.api.expr.test.v1.proto3.TestAllTypes.NestedEnum")))) + .Is()); + EXPECT_TRUE(Type(ErrorType()).Is()); EXPECT_TRUE(Type(FunctionType(&arena, DynType(), {})).Is()); @@ -88,6 +109,15 @@ TEST(Type, Is) { EXPECT_TRUE(Type(MapType()).Is()); + EXPECT_TRUE(Type(MessageType(ABSL_DIE_IF_NULL( + GetTestingDescriptorPool()->FindMessageTypeByName( + "google.api.expr.test.v1.proto3.TestAllTypes")))) + .IsStruct()); + EXPECT_TRUE(Type(MessageType(ABSL_DIE_IF_NULL( + GetTestingDescriptorPool()->FindMessageTypeByName( + "google.api.expr.test.v1.proto3.TestAllTypes")))) + .IsMessage()); + EXPECT_TRUE(Type(NullType()).Is()); EXPECT_TRUE(Type(OptionalType()).Is()); @@ -133,6 +163,13 @@ TEST(Type, As) { EXPECT_THAT(Type(DynType()).As(), Optional(An())); + EXPECT_THAT( + Type(EnumType( + ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindEnumTypeByName( + "google.api.expr.test.v1.proto3.TestAllTypes.NestedEnum")))) + .As(), + Optional(An())); + EXPECT_THAT(Type(ErrorType()).As(), Optional(An())); EXPECT_TRUE(Type(FunctionType(&arena, DynType(), {})).Is()); @@ -146,6 +183,17 @@ TEST(Type, As) { EXPECT_THAT(Type(MapType()).As(), Optional(An())); + EXPECT_THAT(Type(MessageType(ABSL_DIE_IF_NULL( + GetTestingDescriptorPool()->FindMessageTypeByName( + "google.api.expr.test.v1.proto3.TestAllTypes")))) + .As(), + Optional(An())); + EXPECT_THAT(Type(MessageType(ABSL_DIE_IF_NULL( + GetTestingDescriptorPool()->FindMessageTypeByName( + "google.api.expr.test.v1.proto3.TestAllTypes")))) + .As(), + Optional(An())); + EXPECT_THAT(Type(NullType()).As(), Optional(An())); EXPECT_THAT(Type(OptionalType()).As(), @@ -201,6 +249,12 @@ TEST(Type, Cast) { EXPECT_THAT(static_cast(Type(DynType())), An()); + EXPECT_THAT( + static_cast(Type(EnumType( + ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindEnumTypeByName( + "google.api.expr.test.v1.proto3.TestAllTypes.NestedEnum"))))), + An()); + EXPECT_THAT(static_cast(Type(ErrorType())), An()); EXPECT_TRUE(Type(FunctionType(&arena, DynType(), {})).Is()); @@ -216,6 +270,15 @@ TEST(Type, Cast) { EXPECT_THAT(static_cast(Type(MapType())), An()); + EXPECT_THAT(static_cast(Type(MessageType(ABSL_DIE_IF_NULL( + GetTestingDescriptorPool()->FindMessageTypeByName( + "google.api.expr.test.v1.proto3.TestAllTypes"))))), + An()); + EXPECT_THAT(static_cast(Type(MessageType(ABSL_DIE_IF_NULL( + GetTestingDescriptorPool()->FindMessageTypeByName( + "google.api.expr.test.v1.proto3.TestAllTypes"))))), + An()); + EXPECT_THAT(static_cast(Type(NullType())), An()); EXPECT_THAT(static_cast(Type(OptionalType())), diff --git a/common/types/enum_type.cc b/common/types/enum_type.cc new file mode 100644 index 000000000..064105acb --- /dev/null +++ b/common/types/enum_type.cc @@ -0,0 +1,43 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/base/nullability.h" +#include "absl/base/optimization.h" +#include "absl/strings/str_cat.h" +#include "common/type.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +using google::protobuf::EnumDescriptor; + +bool IsWellKnownEnumType(absl::Nonnull descriptor) { + return descriptor->full_name() == "google.protobuf.NullValue"; +} + +std::string EnumType::DebugString() const { + if (ABSL_PREDICT_TRUE(static_cast(*this))) { + static_assert(sizeof(descriptor_) == 8 || sizeof(descriptor_) == 4, + "sizeof(void*) is neither 8 nor 4"); + return absl::StrCat(name(), "@0x", + absl::Hex(descriptor_, sizeof(descriptor_) == 8 + ? absl::PadSpec::kZeroPad16 + : absl::PadSpec::kZeroPad8)); + } + return std::string(); +} + +} // namespace cel diff --git a/common/types/enum_type.h b/common/types/enum_type.h new file mode 100644 index 000000000..5586e1951 --- /dev/null +++ b/common/types/enum_type.h @@ -0,0 +1,137 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/type.h" +// IWYU pragma: friend "common/type.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPES_ENUM_TYPE_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_TYPES_ENUM_TYPE_H_ + +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "common/type_kind.h" +#include "google/protobuf/descriptor.h" + +namespace cel { + +class Type; + +bool IsWellKnownEnumType( + absl::Nonnull descriptor); + +class EnumType final { + public: + using element_type = const google::protobuf::EnumDescriptor; + + static constexpr TypeKind kKind = TypeKind::kEnum; + + // Constructs `EnumType` from a pointer to `google::protobuf::EnumDescriptor`. The + // `google::protobuf::EnumDescriptor` must not be one of the well known enum types we + // treat specially, if it is behavior is undefined. If you are unsure, you + // should use `Type::Enum`. + explicit EnumType(absl::Nullable descriptor) + : descriptor_(descriptor) { + ABSL_DCHECK(descriptor == nullptr || !IsWellKnownEnumType(descriptor)) + << descriptor->full_name(); + } + + EnumType() = default; + EnumType(const EnumType&) = default; + EnumType(EnumType&&) = default; + EnumType& operator=(const EnumType&) = default; + EnumType& operator=(EnumType&&) = default; + + static TypeKind kind() { return kKind; } + + absl::string_view name() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return (*this)->full_name(); + } + + std::string DebugString() const; + + absl::Span parameters() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_DCHECK(*this); + return {}; + } + + const google::protobuf::EnumDescriptor& operator*() const { + ABSL_DCHECK(*this); + return *descriptor_; + } + + absl::Nonnull operator->() const { + ABSL_DCHECK(*this); + return descriptor_; + } + + explicit operator bool() const { return descriptor_ != nullptr; } + + friend void swap(EnumType& lhs, EnumType& rhs) noexcept { + using std::swap; + swap(lhs.descriptor_, rhs.descriptor_); + } + + private: + friend struct std::pointer_traits; + + absl::Nullable descriptor_ = nullptr; +}; + +inline bool operator==(EnumType lhs, EnumType rhs) { + return static_cast(lhs) == static_cast(rhs) && + (!static_cast(lhs) || lhs.name() == rhs.name()); +} + +inline bool operator!=(EnumType lhs, EnumType rhs) { + return !operator==(lhs, rhs); +} + +template +H AbslHashValue(H state, EnumType enum_type) { + return H::combine(std::move(state), static_cast(enum_type) + ? enum_type.name() + : absl::string_view()); +} + +inline std::ostream& operator<<(std::ostream& out, EnumType type) { + return out << type.DebugString(); +} + +} // namespace cel + +namespace std { + +template <> +struct pointer_traits { + using pointer = cel::EnumType; + using element_type = typename cel::EnumType::element_type; + using difference_type = ptrdiff_t; + + static element_type* to_address(const pointer& p) noexcept { + return p.descriptor_; + } +}; + +} // namespace std + +#endif // THIRD_PARTY_CEL_CPP_COMMON_TYPES_ENUM_TYPE_H_ diff --git a/common/types/enum_type_test.cc b/common/types/enum_type_test.cc new file mode 100644 index 000000000..d18984b7a --- /dev/null +++ b/common/types/enum_type_test.cc @@ -0,0 +1,66 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "google/protobuf/descriptor.pb.h" +#include "common/memory.h" +#include "common/type.h" +#include "common/type_kind.h" +#include "internal/testing.h" +#include "google/protobuf/descriptor.h" + +namespace cel { +namespace { + +using testing::Eq; +using testing::IsEmpty; +using testing::NotNull; +using testing::StartsWith; + +TEST(EnumType, Kind) { EXPECT_EQ(EnumType::kind(), TypeKind::kEnum); } + +TEST(EnumType, Default) { + EnumType type; + EXPECT_FALSE(type); + EXPECT_THAT(type.DebugString(), Eq("")); + EXPECT_EQ(type, EnumType()); +} + +TEST(EnumType, Descriptor) { + google::protobuf::DescriptorPool pool; + { + google::protobuf::FileDescriptorProto file_desc_proto; + file_desc_proto.set_syntax("proto3"); + file_desc_proto.set_package("test"); + file_desc_proto.set_name("test/enum.proto"); + auto* enum_desc = file_desc_proto.add_enum_type(); + enum_desc->set_name("Enum"); + auto* enum_value_desc = enum_desc->add_value(); + enum_value_desc->set_number(0); + enum_value_desc->set_name("VALUE"); + ASSERT_THAT(pool.BuildFile(file_desc_proto), NotNull()); + } + const google::protobuf::EnumDescriptor* desc = pool.FindEnumTypeByName("test.Enum"); + ASSERT_THAT(desc, NotNull()); + EnumType type(desc); + EXPECT_TRUE(type); + EXPECT_THAT(type.name(), Eq("test.Enum")); + EXPECT_THAT(type.DebugString(), StartsWith("test.Enum@0x")); + EXPECT_THAT(type.parameters(), IsEmpty()); + EXPECT_NE(type, EnumType()); + EXPECT_NE(EnumType(), type); + EXPECT_EQ(cel::to_address(type), desc); +} + +} // namespace +} // namespace cel diff --git a/common/types/types.h b/common/types/types.h index bbe2b733a..dcbe44189 100644 --- a/common/types/types.h +++ b/common/types/types.h @@ -34,6 +34,7 @@ class DoubleType; class DoubleWrapperType; class DurationType; class DynType; +class EnumType; class ErrorType; class FunctionType; class IntType; @@ -65,16 +66,16 @@ struct IsTypeAlternative std::is_same, std::is_same, std::is_same, std::is_same, std::is_same, std::is_same, - std::is_same, std::is_same, - std::is_same, std::is_same, - std::is_same, std::is_same, - std::is_same, std::is_same, - std::is_same, std::is_same, - std::is_same, std::is_same, - std::is_same, std::is_same, - std::is_same, std::is_same, - std::is_same, std::is_same, - std::is_same>> {}; + std::is_same, std::is_same, + std::is_same, std::is_same, + std::is_same, std::is_same, + std::is_same, std::is_same, + std::is_same, std::is_same, + std::is_same, std::is_same, + std::is_same, std::is_same, + std::is_same, std::is_same, + std::is_same, std::is_same, + std::is_same, std::is_same>> {}; template inline constexpr bool IsTypeAlternativeV = IsTypeAlternative::value; @@ -82,11 +83,11 @@ inline constexpr bool IsTypeAlternativeV = IsTypeAlternative::value; using TypeVariant = absl::variant; + DurationType, DynType, EnumType, ErrorType, FunctionType, + IntType, IntWrapperType, ListType, MapType, NullType, + OpaqueType, StringType, StringWrapperType, MessageType, + BasicStructType, TimestampType, TypeParamType, TypeType, + UintType, UintWrapperType, UnknownType>; using StructTypeVariant = absl::variant;