Skip to content

Commit

Permalink
Add EnumType which covers the enum proposal, which is type only
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 663846915
  • Loading branch information
jcking authored and copybara-github committed Aug 16, 2024
1 parent 0867090 commit 25bc83d
Show file tree
Hide file tree
Showing 10 changed files with 379 additions and 24 deletions.
2 changes: 2 additions & 0 deletions common/kind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ absl::string_view KindToString(Kind kind) {
return "google.protobuf.StringValue";
case Kind::kBytesWrapper:
return "google.protobuf.BytesValue";
case Kind::kEnum:
return "enum";
default:
return "*error*";
}
Expand Down
1 change: 1 addition & 0 deletions common/kind.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ enum class Kind /* : uint8_t */ {

kTypeParam,
kFunction,
kEnum,

// Legacy aliases, deprecated do not use.
kNullType = kNull,
Expand Down
33 changes: 25 additions & 8 deletions common/type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,19 +69,27 @@ Type Type::Message(absl::Nonnull<const Descriptor*> descriptor) {
return MessageType(descriptor);
}
}

Type Type::Enum(absl::Nonnull<const google::protobuf::EnumDescriptor*> descriptor) {
if (descriptor->full_name() == "google.protobuf.NullValue") {
return NullType();
}
return EnumType(descriptor);
}

namespace {

static constexpr std::array<TypeKind, 28> kTypeToKindArray = {
static constexpr std::array<TypeKind, 29> kTypeToKindArray = {
TypeKind::kError, TypeKind::kAny, TypeKind::kBool,
TypeKind::kBoolWrapper, TypeKind::kBytes, TypeKind::kBytesWrapper,
TypeKind::kDouble, TypeKind::kDoubleWrapper, TypeKind::kDuration,
TypeKind::kDyn, TypeKind::kError, TypeKind::kFunction,
TypeKind::kInt, TypeKind::kIntWrapper, TypeKind::kList,
TypeKind::kMap, TypeKind::kNull, TypeKind::kOpaque,
TypeKind::kString, TypeKind::kStringWrapper, TypeKind::kStruct,
TypeKind::kStruct, TypeKind::kTimestamp, TypeKind::kTypeParam,
TypeKind::kType, TypeKind::kUint, TypeKind::kUintWrapper,
TypeKind::kUnknown};
TypeKind::kDyn, TypeKind::kEnum, TypeKind::kError,
TypeKind::kFunction, TypeKind::kInt, TypeKind::kIntWrapper,
TypeKind::kList, TypeKind::kMap, TypeKind::kNull,
TypeKind::kOpaque, TypeKind::kString, TypeKind::kStringWrapper,
TypeKind::kStruct, TypeKind::kStruct, TypeKind::kTimestamp,
TypeKind::kTypeParam, TypeKind::kType, TypeKind::kUint,
TypeKind::kUintWrapper, TypeKind::kUnknown};

static_assert(kTypeToKindArray.size() ==
absl::variant_size<common_internal::TypeVariant>(),
Expand Down Expand Up @@ -222,6 +230,10 @@ absl::optional<DynType> Type::AsDyn() const {
return GetOrNullopt<DynType>(variant_);
}

absl::optional<EnumType> Type::AsEnum() const {
return GetOrNullopt<EnumType>(variant_);
}

absl::optional<ErrorType> Type::AsError() const {
return GetOrNullopt<ErrorType>(variant_);
}
Expand Down Expand Up @@ -363,6 +375,11 @@ Type::operator DynType() const {
return GetOrDie<DynType>(variant_);
}

Type::operator EnumType() const {
ABSL_DCHECK(IsEnum()) << DebugString();
return GetOrDie<EnumType>(variant_);
}

Type::operator ErrorType() const {
ABSL_DCHECK(IsError()) << DebugString();
return GetOrDie<ErrorType>(variant_);
Expand Down
24 changes: 24 additions & 0 deletions common/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
#include "common/types/double_wrapper_type.h" // IWYU pragma: export
#include "common/types/duration_type.h" // IWYU pragma: export
#include "common/types/dyn_type.h" // IWYU pragma: export
#include "common/types/enum_type.h" // IWYU pragma: export
#include "common/types/error_type.h" // IWYU pragma: export
#include "common/types/function_type.h" // IWYU pragma: export
#include "common/types/int_type.h" // IWYU pragma: export
Expand Down Expand Up @@ -86,6 +87,12 @@ class Type final {
static Type Message(absl::Nonnull<const google::protobuf::Descriptor*> descriptor
ABSL_ATTRIBUTE_LIFETIME_BOUND);

// Returns an appropriate `Type` for the dynamic protobuf enum. For well
// known enum types, the appropriate `Type` is returned. All others return
// `EnumType`.
static Type Enum(absl::Nonnull<const google::protobuf::EnumDescriptor*> descriptor
ABSL_ATTRIBUTE_LIFETIME_BOUND);

Type() = default;
Type(const Type&) = default;
Type(Type&&) = default;
Expand Down Expand Up @@ -205,6 +212,8 @@ class Type final {

bool IsDyn() const { return absl::holds_alternative<DynType>(variant_); }

bool IsEnum() const { return absl::holds_alternative<EnumType>(variant_); }

bool IsError() const { return absl::holds_alternative<ErrorType>(variant_); }

bool IsFunction() const {
Expand Down Expand Up @@ -317,6 +326,11 @@ class Type final {
return IsDyn();
}

template <typename T>
std::enable_if_t<std::is_same_v<EnumType, T>, bool> Is() const {
return IsEnum();
}

template <typename T>
std::enable_if_t<std::is_same_v<ErrorType, T>, bool> Is() const {
return IsError();
Expand Down Expand Up @@ -430,6 +444,8 @@ class Type final {

absl::optional<DynType> AsDyn() const;

absl::optional<EnumType> AsEnum() const;

absl::optional<ErrorType> AsError() const;

absl::optional<FunctionType> AsFunction() const;
Expand Down Expand Up @@ -534,6 +550,12 @@ class Type final {
return AsDyn();
}

template <typename T>
std::enable_if_t<std::is_same_v<EnumType, T>, absl::optional<EnumType>> As()
const {
return AsEnum();
}

template <typename T>
std::enable_if_t<std::is_same_v<ErrorType, T>, absl::optional<ErrorType>> As()
const {
Expand Down Expand Up @@ -677,6 +699,8 @@ class Type final {

explicit operator DynType() const;

explicit operator EnumType() const;

explicit operator ErrorType() const;

explicit operator FunctionType() const;
Expand Down
1 change: 1 addition & 0 deletions common/type_kind.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ enum class TypeKind : std::underlying_type_t<Kind> {

kTypeParam = static_cast<int>(Kind::kTypeParam),
kFunction = static_cast<int>(Kind::kFunction),
kEnum = static_cast<int>(Kind::kEnum),

// Legacy aliases, deprecated do not use.
kNullType = kNull,
Expand Down
65 changes: 64 additions & 1 deletion common/type_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,32 @@

#include "absl/hash/hash.h"
#include "absl/hash/hash_testing.h"
#include "common/native_type.h"
#include "absl/log/die_if_null.h"
#include "internal/testing.h"
#include "internal/testing_descriptor_pool.h"
#include "google/protobuf/arena.h"

namespace cel {
namespace {

using ::cel::internal::GetTestingDescriptorPool;
using testing::_;
using testing::An;
using testing::Optional;

TEST(Type, Enum) {
EXPECT_EQ(
Type::Enum(
ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindEnumTypeByName(
"google.api.expr.test.v1.proto3.TestAllTypes.NestedEnum"))),
EnumType(ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindEnumTypeByName(
"google.api.expr.test.v1.proto3.TestAllTypes.NestedEnum"))));
EXPECT_EQ(Type::Enum(
ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindEnumTypeByName(
"google.protobuf.NullValue"))),
NullType());
}

TEST(Type, KindDebugDeath) {
Type type;
static_cast<void>(type);
Expand Down Expand Up @@ -75,6 +90,12 @@ TEST(Type, Is) {

EXPECT_TRUE(Type(DynType()).Is<DynType>());

EXPECT_TRUE(
Type(EnumType(
ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindEnumTypeByName(
"google.api.expr.test.v1.proto3.TestAllTypes.NestedEnum"))))
.Is<EnumType>());

EXPECT_TRUE(Type(ErrorType()).Is<ErrorType>());

EXPECT_TRUE(Type(FunctionType(&arena, DynType(), {})).Is<FunctionType>());
Expand All @@ -88,6 +109,15 @@ TEST(Type, Is) {

EXPECT_TRUE(Type(MapType()).Is<MapType>());

EXPECT_TRUE(Type(MessageType(ABSL_DIE_IF_NULL(
GetTestingDescriptorPool()->FindMessageTypeByName(
"google.api.expr.test.v1.proto3.TestAllTypes"))))
.IsStruct());
EXPECT_TRUE(Type(MessageType(ABSL_DIE_IF_NULL(
GetTestingDescriptorPool()->FindMessageTypeByName(
"google.api.expr.test.v1.proto3.TestAllTypes"))))
.IsMessage());

EXPECT_TRUE(Type(NullType()).Is<NullType>());

EXPECT_TRUE(Type(OptionalType()).Is<OpaqueType>());
Expand Down Expand Up @@ -133,6 +163,13 @@ TEST(Type, As) {

EXPECT_THAT(Type(DynType()).As<DynType>(), Optional(An<DynType>()));

EXPECT_THAT(
Type(EnumType(
ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindEnumTypeByName(
"google.api.expr.test.v1.proto3.TestAllTypes.NestedEnum"))))
.As<EnumType>(),
Optional(An<EnumType>()));

EXPECT_THAT(Type(ErrorType()).As<ErrorType>(), Optional(An<ErrorType>()));

EXPECT_TRUE(Type(FunctionType(&arena, DynType(), {})).Is<FunctionType>());
Expand All @@ -146,6 +183,17 @@ TEST(Type, As) {

EXPECT_THAT(Type(MapType()).As<MapType>(), Optional(An<MapType>()));

EXPECT_THAT(Type(MessageType(ABSL_DIE_IF_NULL(
GetTestingDescriptorPool()->FindMessageTypeByName(
"google.api.expr.test.v1.proto3.TestAllTypes"))))
.As<StructType>(),
Optional(An<StructType>()));
EXPECT_THAT(Type(MessageType(ABSL_DIE_IF_NULL(
GetTestingDescriptorPool()->FindMessageTypeByName(
"google.api.expr.test.v1.proto3.TestAllTypes"))))
.As<MessageType>(),
Optional(An<MessageType>()));

EXPECT_THAT(Type(NullType()).As<NullType>(), Optional(An<NullType>()));

EXPECT_THAT(Type(OptionalType()).As<OptionalType>(),
Expand Down Expand Up @@ -201,6 +249,12 @@ TEST(Type, Cast) {

EXPECT_THAT(static_cast<DynType>(Type(DynType())), An<DynType>());

EXPECT_THAT(
static_cast<EnumType>(Type(EnumType(
ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindEnumTypeByName(
"google.api.expr.test.v1.proto3.TestAllTypes.NestedEnum"))))),
An<EnumType>());

EXPECT_THAT(static_cast<ErrorType>(Type(ErrorType())), An<ErrorType>());

EXPECT_TRUE(Type(FunctionType(&arena, DynType(), {})).Is<FunctionType>());
Expand All @@ -216,6 +270,15 @@ TEST(Type, Cast) {

EXPECT_THAT(static_cast<MapType>(Type(MapType())), An<MapType>());

EXPECT_THAT(static_cast<StructType>(Type(MessageType(ABSL_DIE_IF_NULL(
GetTestingDescriptorPool()->FindMessageTypeByName(
"google.api.expr.test.v1.proto3.TestAllTypes"))))),
An<StructType>());
EXPECT_THAT(static_cast<MessageType>(Type(MessageType(ABSL_DIE_IF_NULL(
GetTestingDescriptorPool()->FindMessageTypeByName(
"google.api.expr.test.v1.proto3.TestAllTypes"))))),
An<MessageType>());

EXPECT_THAT(static_cast<NullType>(Type(NullType())), An<NullType>());

EXPECT_THAT(static_cast<OptionalType>(Type(OptionalType())),
Expand Down
43 changes: 43 additions & 0 deletions common/types/enum_type.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <string>

#include "absl/base/nullability.h"
#include "absl/base/optimization.h"
#include "absl/strings/str_cat.h"
#include "common/type.h"
#include "google/protobuf/descriptor.h"

namespace cel {

using google::protobuf::EnumDescriptor;

bool IsWellKnownEnumType(absl::Nonnull<const EnumDescriptor*> descriptor) {
return descriptor->full_name() == "google.protobuf.NullValue";
}

std::string EnumType::DebugString() const {
if (ABSL_PREDICT_TRUE(static_cast<bool>(*this))) {
static_assert(sizeof(descriptor_) == 8 || sizeof(descriptor_) == 4,
"sizeof(void*) is neither 8 nor 4");
return absl::StrCat(name(), "@0x",
absl::Hex(descriptor_, sizeof(descriptor_) == 8
? absl::PadSpec::kZeroPad16
: absl::PadSpec::kZeroPad8));
}
return std::string();
}

} // namespace cel
Loading

0 comments on commit 25bc83d

Please sign in to comment.