From 29d5f28f06d89a877b8ffb715623593ff9f529c7 Mon Sep 17 00:00:00 2001 From: Alvaro Aleman Date: Sun, 28 Jan 2024 20:11:17 -0500 Subject: [PATCH] DefaultTypeAdapter: Add support for missing custom scalars The default type adapter already supports some custom scalar types but not all, this change adds the missing ones. The most only used one is likely string. --- common/types/int.go | 12 +++++++++++ common/types/int_test.go | 30 ++++++++++++++++++++++++++ common/types/overflow.go | 40 +++++++++++++++++++++++++++++++++++ common/types/provider.go | 24 +++++++++++++++++++++ common/types/provider_test.go | 32 +++++++++++++++++++++++++++- common/types/string.go | 5 +---- common/types/string_test.go | 11 ++++++++++ common/types/uint.go | 12 +++++++++++ common/types/uint_test.go | 32 ++++++++++++++++++++++++++++ 9 files changed, 193 insertions(+), 5 deletions(-) diff --git a/common/types/int.go b/common/types/int.go index 940772ae..0ae9507c 100644 --- a/common/types/int.go +++ b/common/types/int.go @@ -90,6 +90,18 @@ func (i Int) ConvertToNative(typeDesc reflect.Type) (any, error) { return nil, err } return reflect.ValueOf(v).Convert(typeDesc).Interface(), nil + case reflect.Int8: + v, err := int64ToInt8Checked(int64(i)) + if err != nil { + return nil, err + } + return reflect.ValueOf(v).Convert(typeDesc).Interface(), nil + case reflect.Int16: + v, err := int64ToInt16Checked(int64(i)) + if err != nil { + return nil, err + } + return reflect.ValueOf(v).Convert(typeDesc).Interface(), nil case reflect.Int64: return reflect.ValueOf(i).Convert(typeDesc).Interface(), nil case reflect.Ptr: diff --git a/common/types/int_test.go b/common/types/int_test.go index 9f76bb77..6529843c 100644 --- a/common/types/int_test.go +++ b/common/types/int_test.go @@ -165,6 +165,36 @@ func TestIntConvertToNative_Error(t *testing.T) { } } +func TestIntConvertToNative_Int8(t *testing.T) { + val, err := Int(127).ConvertToNative(reflect.TypeOf(int8(0))) + if err != nil { + t.Fatalf("Int.ConvertToNative(int8) failed: %v", err) + } + if val.(int8) != 127 { + t.Errorf("Got '%v', expected 20050", val) + } + val, err = Int(math.MaxInt8 + 1).ConvertToNative(reflect.TypeOf(int8(0))) + if err == nil { + t.Errorf("(MaxInt+1).ConvertToNative(int8) did not error, got: %v", val) + } else if !strings.Contains(err.Error(), "integer overflow") { + t.Errorf("ConvertToNative(int8) returned unexpected error: %v, wanted integer overflow", err) + } +} +func TestIntConvertToNative_Int16(t *testing.T) { + val, err := Int(20050).ConvertToNative(reflect.TypeOf(int16(0))) + if err != nil { + t.Fatalf("Int.ConvertToNative(int16) failed: %v", err) + } + if val.(int16) != 20050 { + t.Errorf("Got '%v', expected 20050", val) + } + val, err = Int(math.MaxInt16 + 1).ConvertToNative(reflect.TypeOf(int16(0))) + if err == nil { + t.Errorf("(MaxInt+1).ConvertToNative(int16) did not error, got: %v", val) + } else if !strings.Contains(err.Error(), "integer overflow") { + t.Errorf("ConvertToNative(int32) returned unexpected error: %v, wanted integer overflow", err) + } +} func TestIntConvertToNative_Int32(t *testing.T) { val, err := Int(20050).ConvertToNative(reflect.TypeOf(int32(0))) if err != nil { diff --git a/common/types/overflow.go b/common/types/overflow.go index c68a9218..dcb66ef5 100644 --- a/common/types/overflow.go +++ b/common/types/overflow.go @@ -326,6 +326,26 @@ func int64ToUint64Checked(v int64) (uint64, error) { return uint64(v), nil } +// int64ToInt8Checked converts an int64 to an int8 value. +// +// If the conversion fails due to overflow the error return value will be non-nil. +func int64ToInt8Checked(v int64) (int8, error) { + if v < math.MinInt8 || v > math.MaxInt8 { + return 0, errIntOverflow + } + return int8(v), nil +} + +// int64ToInt16Checked converts an int64 to an int16 value. +// +// If the conversion fails due to overflow the error return value will be non-nil. +func int64ToInt16Checked(v int64) (int16, error) { + if v < math.MinInt16 || v > math.MaxInt16 { + return 0, errIntOverflow + } + return int16(v), nil +} + // int64ToInt32Checked converts an int64 to an int32 value. // // If the conversion fails due to overflow the error return value will be non-nil. @@ -336,6 +356,26 @@ func int64ToInt32Checked(v int64) (int32, error) { return int32(v), nil } +// uint64ToUint8Checked converts a uint64 to a uint8 value. +// +// If the conversion fails due to overflow the error return value will be non-nil. +func uint64ToUint8Checked(v uint64) (uint8, error) { + if v > math.MaxUint8 { + return 0, errUintOverflow + } + return uint8(v), nil +} + +// uint64ToUint16Checked converts a uint64 to a uint16 value. +// +// If the conversion fails due to overflow the error return value will be non-nil. +func uint64ToUint16Checked(v uint64) (uint16, error) { + if v > math.MaxUint16 { + return 0, errUintOverflow + } + return uint16(v), nil +} + // uint64ToUint32Checked converts a uint64 to a uint32 value. // // If the conversion fails due to overflow the error return value will be non-nil. diff --git a/common/types/provider.go b/common/types/provider.go index 5157cd1f..c5ff05fd 100644 --- a/common/types/provider.go +++ b/common/types/provider.go @@ -590,12 +590,33 @@ func nativeToValue(a Adapter, value any) (ref.Val, bool) { return NewDynamicMap(a, v), true // type aliases of primitive types cannot be asserted as that type, but rather need // to be downcast to int32 before being converted to a CEL representation. + case reflect.Bool: + boolTupe := reflect.TypeOf(false) + return Bool(refValue.Convert(boolTupe).Interface().(bool)), true + case reflect.Int: + intType := reflect.TypeOf(int(0)) + return Int(refValue.Convert(intType).Interface().(int)), true + case reflect.Int8: + intType := reflect.TypeOf(int8(0)) + return Int(refValue.Convert(intType).Interface().(int8)), true + case reflect.Int16: + intType := reflect.TypeOf(int16(0)) + return Int(refValue.Convert(intType).Interface().(int16)), true case reflect.Int32: intType := reflect.TypeOf(int32(0)) return Int(refValue.Convert(intType).Interface().(int32)), true case reflect.Int64: intType := reflect.TypeOf(int64(0)) return Int(refValue.Convert(intType).Interface().(int64)), true + case reflect.Uint: + uintType := reflect.TypeOf(uint(0)) + return Uint(refValue.Convert(uintType).Interface().(uint)), true + case reflect.Uint8: + uintType := reflect.TypeOf(uint8(0)) + return Uint(refValue.Convert(uintType).Interface().(uint8)), true + case reflect.Uint16: + uintType := reflect.TypeOf(uint16(0)) + return Uint(refValue.Convert(uintType).Interface().(uint16)), true case reflect.Uint32: uintType := reflect.TypeOf(uint32(0)) return Uint(refValue.Convert(uintType).Interface().(uint32)), true @@ -608,6 +629,9 @@ func nativeToValue(a Adapter, value any) (ref.Val, bool) { case reflect.Float64: doubleType := reflect.TypeOf(float64(0)) return Double(refValue.Convert(doubleType).Interface().(float64)), true + case reflect.String: + stringType := reflect.TypeOf("") + return String(refValue.Convert(stringType).Interface().(string)), true } } return nil, false diff --git a/common/types/provider_test.go b/common/types/provider_test.go index 56f15290..efe1244a 100644 --- a/common/types/provider_test.go +++ b/common/types/provider_test.go @@ -566,6 +566,21 @@ func TestConvertToNative(t *testing.T) { // Proto conversion tests. parsedExpr := &exprpb.ParsedExpr{} expectValueToNative(t, reg.NativeToValue(parsedExpr), parsedExpr) + + // Custom scalars + expectValueToNative(t, Int(1), testInt(1)) + expectValueToNative(t, Int(1), testInt8(1)) + expectValueToNative(t, Int(1), testInt16(1)) + expectValueToNative(t, Int(1), testInt32(1)) + expectValueToNative(t, Int(1), testInt64(1)) + expectValueToNative(t, Uint(1), testUint(1)) + expectValueToNative(t, Uint(1), testUint8(1)) + expectValueToNative(t, Uint(1), testUint16(1)) + expectValueToNative(t, Uint(1), testUint32(1)) + expectValueToNative(t, Uint(1), testUint64(1)) + expectValueToNative(t, Double(4.5), testFloat32(4.5)) + expectValueToNative(t, Double(-5.1), testFloat64(-5.1)) + expectValueToNative(t, String("foo"), testString("foo")) } func TestNativeToValue_Any(t *testing.T) { @@ -758,12 +773,19 @@ func TestNativeToValue_Primitive(t *testing.T) { expectNativeToValue(t, &rBytes, rBytes) // Extensions to core types. + expectNativeToValue(t, testInt(1), Int(1)) + expectNativeToValue(t, testInt8(1), Int(1)) + expectNativeToValue(t, testInt16(1), Int(1)) expectNativeToValue(t, testInt32(1), Int(1)) expectNativeToValue(t, testInt64(-100), Int(-100)) + expectNativeToValue(t, testUint(1), Uint(1)) + expectNativeToValue(t, testUint8(1), Uint(1)) + expectNativeToValue(t, testUint16(1), Uint(1)) expectNativeToValue(t, testUint32(2), Uint(2)) expectNativeToValue(t, testUint64(3), Uint(3)) expectNativeToValue(t, testFloat32(4.5), Double(4.5)) expectNativeToValue(t, testFloat64(-5.1), Double(-5.1)) + expectNativeToValue(t, testString("foo"), String("foo")) // Null conversion test. expectNativeToValue(t, nil, NullValue) @@ -795,7 +817,7 @@ func expectValueToNative(t *testing.T, in ref.Val, out any) { } if !equals { t.Errorf("Unexpected conversion from expr to proto.\n"+ - "expected: %T, actual: %T", val, out) + "expected: %T, actual: %T", out, val) } } } @@ -870,12 +892,20 @@ func BenchmarkTypeProviderCopy(b *testing.B) { type nonConvertible struct { Field string } +type testBool bool +type testInt int +type testInt8 int8 +type testInt16 int16 type testInt32 int32 type testInt64 int64 +type testUint uint +type testUint8 uint8 +type testUint16 uint16 type testUint32 uint32 type testUint64 uint64 type testFloat32 float32 type testFloat64 float64 +type testString string func newTestRegistry(t *testing.T, types ...proto.Message) *Registry { t.Helper() diff --git a/common/types/string.go b/common/types/string.go index a2990b26..3a93743f 100644 --- a/common/types/string.go +++ b/common/types/string.go @@ -66,10 +66,7 @@ func (s String) Compare(other ref.Val) ref.Val { func (s String) ConvertToNative(typeDesc reflect.Type) (any, error) { switch typeDesc.Kind() { case reflect.String: - if reflect.TypeOf(s).AssignableTo(typeDesc) { - return s, nil - } - return s.Value(), nil + return reflect.ValueOf(s).Convert(typeDesc).Interface(), nil case reflect.Ptr: switch typeDesc { case anyValueType: diff --git a/common/types/string_test.go b/common/types/string_test.go index 226b1932..37958535 100644 --- a/common/types/string_test.go +++ b/common/types/string_test.go @@ -106,6 +106,17 @@ func TestStringConvertToNative_String(t *testing.T) { } } +type customString string + +func TestStringConvertToNative_CustomString(t *testing.T) { + val, err := String("hello").ConvertToNative(reflect.TypeOf(customString(""))) + if err != nil { + t.Error(err) + } else if v, ok := val.(customString); !ok || v != "hello" { + t.Errorf("Got %T with val '%v', expected %T with val 'hello'", val, v, customString("")) + } +} + func TestStringConvertToNative_Wrapper(t *testing.T) { val, err := String("hello").ConvertToNative(stringWrapperType) if err != nil { diff --git a/common/types/uint.go b/common/types/uint.go index 3257f9ad..6d74f30d 100644 --- a/common/types/uint.go +++ b/common/types/uint.go @@ -80,6 +80,18 @@ func (i Uint) ConvertToNative(typeDesc reflect.Type) (any, error) { return 0, err } return reflect.ValueOf(v).Convert(typeDesc).Interface(), nil + case reflect.Uint8: + v, err := uint64ToUint8Checked(uint64(i)) + if err != nil { + return 0, err + } + return reflect.ValueOf(v).Convert(typeDesc).Interface(), nil + case reflect.Uint16: + v, err := uint64ToUint16Checked(uint64(i)) + if err != nil { + return 0, err + } + return reflect.ValueOf(v).Convert(typeDesc).Interface(), nil case reflect.Uint64: return reflect.ValueOf(i).Convert(typeDesc).Interface(), nil case reflect.Ptr: diff --git a/common/types/uint_test.go b/common/types/uint_test.go index 777d7955..f07832ca 100644 --- a/common/types/uint_test.go +++ b/common/types/uint_test.go @@ -172,6 +172,38 @@ func TestUintConvertToNative_Json(t *testing.T) { } } +func TestUintConvertToNative_Uint8(t *testing.T) { + val, err := Uint(128).ConvertToNative(reflect.TypeOf(uint8(0))) + if err != nil { + t.Fatalf("Uint.ConvertToNative(uint8) failed: %v", err) + } + if val.(uint8) != 128 { + t.Errorf("Got '%v', expected 128", val) + } + val, err = Uint(math.MaxUint8 + 1).ConvertToNative(reflect.TypeOf(uint8(0))) + if err == nil { + t.Errorf("(MaxUint+1).ConvertToNative(uint8) did not error, got: %v", val) + } else if !strings.Contains(err.Error(), "unsigned integer overflow") { + t.Errorf("ConvertToNative(uint8) returned unexpected error: %v, wanted unsigned integer overflow", err) + } +} + +func TestUintConvertToNative_Uint16(t *testing.T) { + val, err := Uint(20050).ConvertToNative(reflect.TypeOf(uint16(0))) + if err != nil { + t.Fatalf("Uint.ConvertToNative(uint16) failed: %v", err) + } + if val.(uint16) != 20050 { + t.Errorf("Got '%v', expected 20050", val) + } + val, err = Uint(math.MaxUint16 + 1).ConvertToNative(reflect.TypeOf(uint16(0))) + if err == nil { + t.Errorf("(MaxUint+1).ConvertToNative(uint16) did not error, got: %v", val) + } else if !strings.Contains(err.Error(), "unsigned integer overflow") { + t.Errorf("ConvertToNative(uint16) returned unexpected error: %v, wanted unsigned integer overflow", err) + } +} + func TestUintConvertToNative_Uint32(t *testing.T) { val, err := Uint(20050).ConvertToNative(reflect.TypeOf(uint32(0))) if err != nil {