From 19b2ad1eb2f51cd23b36dfb68f929a7446382378 Mon Sep 17 00:00:00 2001
From: Alvaro Aleman <alvaroaleman@users.noreply.github.com>
Date: Fri, 2 Feb 2024 13:06:20 -0500
Subject: [PATCH] DefaultTypeAdapter: Add support for missing custom scalars
 (#893)

The default type adapter already supports some custom scalar types but
not all, this change adds the missing ones. The most only used one is
likely string.
---
 common/types/int.go           | 12 +++++++++++
 common/types/int_test.go      | 30 ++++++++++++++++++++++++++
 common/types/overflow.go      | 40 +++++++++++++++++++++++++++++++++++
 common/types/provider.go      | 24 +++++++++++++++++++++
 common/types/provider_test.go | 32 +++++++++++++++++++++++++++-
 common/types/string.go        |  5 +----
 common/types/string_test.go   | 11 ++++++++++
 common/types/uint.go          | 12 +++++++++++
 common/types/uint_test.go     | 32 ++++++++++++++++++++++++++++
 9 files changed, 193 insertions(+), 5 deletions(-)

diff --git a/common/types/int.go b/common/types/int.go
index 940772ae..0ae9507c 100644
--- a/common/types/int.go
+++ b/common/types/int.go
@@ -90,6 +90,18 @@ func (i Int) ConvertToNative(typeDesc reflect.Type) (any, error) {
 			return nil, err
 		}
 		return reflect.ValueOf(v).Convert(typeDesc).Interface(), nil
+	case reflect.Int8:
+		v, err := int64ToInt8Checked(int64(i))
+		if err != nil {
+			return nil, err
+		}
+		return reflect.ValueOf(v).Convert(typeDesc).Interface(), nil
+	case reflect.Int16:
+		v, err := int64ToInt16Checked(int64(i))
+		if err != nil {
+			return nil, err
+		}
+		return reflect.ValueOf(v).Convert(typeDesc).Interface(), nil
 	case reflect.Int64:
 		return reflect.ValueOf(i).Convert(typeDesc).Interface(), nil
 	case reflect.Ptr:
diff --git a/common/types/int_test.go b/common/types/int_test.go
index 9f76bb77..6529843c 100644
--- a/common/types/int_test.go
+++ b/common/types/int_test.go
@@ -165,6 +165,36 @@ func TestIntConvertToNative_Error(t *testing.T) {
 	}
 }
 
+func TestIntConvertToNative_Int8(t *testing.T) {
+	val, err := Int(127).ConvertToNative(reflect.TypeOf(int8(0)))
+	if err != nil {
+		t.Fatalf("Int.ConvertToNative(int8) failed: %v", err)
+	}
+	if val.(int8) != 127 {
+		t.Errorf("Got '%v', expected 20050", val)
+	}
+	val, err = Int(math.MaxInt8 + 1).ConvertToNative(reflect.TypeOf(int8(0)))
+	if err == nil {
+		t.Errorf("(MaxInt+1).ConvertToNative(int8) did not error, got: %v", val)
+	} else if !strings.Contains(err.Error(), "integer overflow") {
+		t.Errorf("ConvertToNative(int8) returned unexpected error: %v, wanted integer overflow", err)
+	}
+}
+func TestIntConvertToNative_Int16(t *testing.T) {
+	val, err := Int(20050).ConvertToNative(reflect.TypeOf(int16(0)))
+	if err != nil {
+		t.Fatalf("Int.ConvertToNative(int16) failed: %v", err)
+	}
+	if val.(int16) != 20050 {
+		t.Errorf("Got '%v', expected 20050", val)
+	}
+	val, err = Int(math.MaxInt16 + 1).ConvertToNative(reflect.TypeOf(int16(0)))
+	if err == nil {
+		t.Errorf("(MaxInt+1).ConvertToNative(int16) did not error, got: %v", val)
+	} else if !strings.Contains(err.Error(), "integer overflow") {
+		t.Errorf("ConvertToNative(int32) returned unexpected error: %v, wanted integer overflow", err)
+	}
+}
 func TestIntConvertToNative_Int32(t *testing.T) {
 	val, err := Int(20050).ConvertToNative(reflect.TypeOf(int32(0)))
 	if err != nil {
diff --git a/common/types/overflow.go b/common/types/overflow.go
index c68a9218..dcb66ef5 100644
--- a/common/types/overflow.go
+++ b/common/types/overflow.go
@@ -326,6 +326,26 @@ func int64ToUint64Checked(v int64) (uint64, error) {
 	return uint64(v), nil
 }
 
+// int64ToInt8Checked converts an int64 to an int8 value.
+//
+// If the conversion fails due to overflow the error return value will be non-nil.
+func int64ToInt8Checked(v int64) (int8, error) {
+	if v < math.MinInt8 || v > math.MaxInt8 {
+		return 0, errIntOverflow
+	}
+	return int8(v), nil
+}
+
+// int64ToInt16Checked converts an int64 to an int16 value.
+//
+// If the conversion fails due to overflow the error return value will be non-nil.
+func int64ToInt16Checked(v int64) (int16, error) {
+	if v < math.MinInt16 || v > math.MaxInt16 {
+		return 0, errIntOverflow
+	}
+	return int16(v), nil
+}
+
 // int64ToInt32Checked converts an int64 to an int32 value.
 //
 // If the conversion fails due to overflow the error return value will be non-nil.
@@ -336,6 +356,26 @@ func int64ToInt32Checked(v int64) (int32, error) {
 	return int32(v), nil
 }
 
+// uint64ToUint8Checked converts a uint64 to a uint8 value.
+//
+// If the conversion fails due to overflow the error return value will be non-nil.
+func uint64ToUint8Checked(v uint64) (uint8, error) {
+	if v > math.MaxUint8 {
+		return 0, errUintOverflow
+	}
+	return uint8(v), nil
+}
+
+// uint64ToUint16Checked converts a uint64 to a uint16 value.
+//
+// If the conversion fails due to overflow the error return value will be non-nil.
+func uint64ToUint16Checked(v uint64) (uint16, error) {
+	if v > math.MaxUint16 {
+		return 0, errUintOverflow
+	}
+	return uint16(v), nil
+}
+
 // uint64ToUint32Checked converts a uint64 to a uint32 value.
 //
 // If the conversion fails due to overflow the error return value will be non-nil.
diff --git a/common/types/provider.go b/common/types/provider.go
index 5157cd1f..c5ff05fd 100644
--- a/common/types/provider.go
+++ b/common/types/provider.go
@@ -590,12 +590,33 @@ func nativeToValue(a Adapter, value any) (ref.Val, bool) {
 			return NewDynamicMap(a, v), true
 		// type aliases of primitive types cannot be asserted as that type, but rather need
 		// to be downcast to int32 before being converted to a CEL representation.
+		case reflect.Bool:
+			boolTupe := reflect.TypeOf(false)
+			return Bool(refValue.Convert(boolTupe).Interface().(bool)), true
+		case reflect.Int:
+			intType := reflect.TypeOf(int(0))
+			return Int(refValue.Convert(intType).Interface().(int)), true
+		case reflect.Int8:
+			intType := reflect.TypeOf(int8(0))
+			return Int(refValue.Convert(intType).Interface().(int8)), true
+		case reflect.Int16:
+			intType := reflect.TypeOf(int16(0))
+			return Int(refValue.Convert(intType).Interface().(int16)), true
 		case reflect.Int32:
 			intType := reflect.TypeOf(int32(0))
 			return Int(refValue.Convert(intType).Interface().(int32)), true
 		case reflect.Int64:
 			intType := reflect.TypeOf(int64(0))
 			return Int(refValue.Convert(intType).Interface().(int64)), true
+		case reflect.Uint:
+			uintType := reflect.TypeOf(uint(0))
+			return Uint(refValue.Convert(uintType).Interface().(uint)), true
+		case reflect.Uint8:
+			uintType := reflect.TypeOf(uint8(0))
+			return Uint(refValue.Convert(uintType).Interface().(uint8)), true
+		case reflect.Uint16:
+			uintType := reflect.TypeOf(uint16(0))
+			return Uint(refValue.Convert(uintType).Interface().(uint16)), true
 		case reflect.Uint32:
 			uintType := reflect.TypeOf(uint32(0))
 			return Uint(refValue.Convert(uintType).Interface().(uint32)), true
@@ -608,6 +629,9 @@ func nativeToValue(a Adapter, value any) (ref.Val, bool) {
 		case reflect.Float64:
 			doubleType := reflect.TypeOf(float64(0))
 			return Double(refValue.Convert(doubleType).Interface().(float64)), true
+		case reflect.String:
+			stringType := reflect.TypeOf("")
+			return String(refValue.Convert(stringType).Interface().(string)), true
 		}
 	}
 	return nil, false
diff --git a/common/types/provider_test.go b/common/types/provider_test.go
index 56f15290..efe1244a 100644
--- a/common/types/provider_test.go
+++ b/common/types/provider_test.go
@@ -566,6 +566,21 @@ func TestConvertToNative(t *testing.T) {
 	// Proto conversion tests.
 	parsedExpr := &exprpb.ParsedExpr{}
 	expectValueToNative(t, reg.NativeToValue(parsedExpr), parsedExpr)
+
+	// Custom scalars
+	expectValueToNative(t, Int(1), testInt(1))
+	expectValueToNative(t, Int(1), testInt8(1))
+	expectValueToNative(t, Int(1), testInt16(1))
+	expectValueToNative(t, Int(1), testInt32(1))
+	expectValueToNative(t, Int(1), testInt64(1))
+	expectValueToNative(t, Uint(1), testUint(1))
+	expectValueToNative(t, Uint(1), testUint8(1))
+	expectValueToNative(t, Uint(1), testUint16(1))
+	expectValueToNative(t, Uint(1), testUint32(1))
+	expectValueToNative(t, Uint(1), testUint64(1))
+	expectValueToNative(t, Double(4.5), testFloat32(4.5))
+	expectValueToNative(t, Double(-5.1), testFloat64(-5.1))
+	expectValueToNative(t, String("foo"), testString("foo"))
 }
 
 func TestNativeToValue_Any(t *testing.T) {
@@ -758,12 +773,19 @@ func TestNativeToValue_Primitive(t *testing.T) {
 	expectNativeToValue(t, &rBytes, rBytes)
 
 	// Extensions to core types.
+	expectNativeToValue(t, testInt(1), Int(1))
+	expectNativeToValue(t, testInt8(1), Int(1))
+	expectNativeToValue(t, testInt16(1), Int(1))
 	expectNativeToValue(t, testInt32(1), Int(1))
 	expectNativeToValue(t, testInt64(-100), Int(-100))
+	expectNativeToValue(t, testUint(1), Uint(1))
+	expectNativeToValue(t, testUint8(1), Uint(1))
+	expectNativeToValue(t, testUint16(1), Uint(1))
 	expectNativeToValue(t, testUint32(2), Uint(2))
 	expectNativeToValue(t, testUint64(3), Uint(3))
 	expectNativeToValue(t, testFloat32(4.5), Double(4.5))
 	expectNativeToValue(t, testFloat64(-5.1), Double(-5.1))
+	expectNativeToValue(t, testString("foo"), String("foo"))
 
 	// Null conversion test.
 	expectNativeToValue(t, nil, NullValue)
@@ -795,7 +817,7 @@ func expectValueToNative(t *testing.T, in ref.Val, out any) {
 		}
 		if !equals {
 			t.Errorf("Unexpected conversion from expr to proto.\n"+
-				"expected: %T, actual: %T", val, out)
+				"expected: %T, actual: %T", out, val)
 		}
 	}
 }
@@ -870,12 +892,20 @@ func BenchmarkTypeProviderCopy(b *testing.B) {
 type nonConvertible struct {
 	Field string
 }
+type testBool bool
+type testInt int
+type testInt8 int8
+type testInt16 int16
 type testInt32 int32
 type testInt64 int64
+type testUint uint
+type testUint8 uint8
+type testUint16 uint16
 type testUint32 uint32
 type testUint64 uint64
 type testFloat32 float32
 type testFloat64 float64
+type testString string
 
 func newTestRegistry(t *testing.T, types ...proto.Message) *Registry {
 	t.Helper()
diff --git a/common/types/string.go b/common/types/string.go
index a2990b26..3a93743f 100644
--- a/common/types/string.go
+++ b/common/types/string.go
@@ -66,10 +66,7 @@ func (s String) Compare(other ref.Val) ref.Val {
 func (s String) ConvertToNative(typeDesc reflect.Type) (any, error) {
 	switch typeDesc.Kind() {
 	case reflect.String:
-		if reflect.TypeOf(s).AssignableTo(typeDesc) {
-			return s, nil
-		}
-		return s.Value(), nil
+		return reflect.ValueOf(s).Convert(typeDesc).Interface(), nil
 	case reflect.Ptr:
 		switch typeDesc {
 		case anyValueType:
diff --git a/common/types/string_test.go b/common/types/string_test.go
index 226b1932..37958535 100644
--- a/common/types/string_test.go
+++ b/common/types/string_test.go
@@ -106,6 +106,17 @@ func TestStringConvertToNative_String(t *testing.T) {
 	}
 }
 
+type customString string
+
+func TestStringConvertToNative_CustomString(t *testing.T) {
+	val, err := String("hello").ConvertToNative(reflect.TypeOf(customString("")))
+	if err != nil {
+		t.Error(err)
+	} else if v, ok := val.(customString); !ok || v != "hello" {
+		t.Errorf("Got %T with val '%v', expected %T with val 'hello'", val, v, customString(""))
+	}
+}
+
 func TestStringConvertToNative_Wrapper(t *testing.T) {
 	val, err := String("hello").ConvertToNative(stringWrapperType)
 	if err != nil {
diff --git a/common/types/uint.go b/common/types/uint.go
index 3257f9ad..6d74f30d 100644
--- a/common/types/uint.go
+++ b/common/types/uint.go
@@ -80,6 +80,18 @@ func (i Uint) ConvertToNative(typeDesc reflect.Type) (any, error) {
 			return 0, err
 		}
 		return reflect.ValueOf(v).Convert(typeDesc).Interface(), nil
+	case reflect.Uint8:
+		v, err := uint64ToUint8Checked(uint64(i))
+		if err != nil {
+			return 0, err
+		}
+		return reflect.ValueOf(v).Convert(typeDesc).Interface(), nil
+	case reflect.Uint16:
+		v, err := uint64ToUint16Checked(uint64(i))
+		if err != nil {
+			return 0, err
+		}
+		return reflect.ValueOf(v).Convert(typeDesc).Interface(), nil
 	case reflect.Uint64:
 		return reflect.ValueOf(i).Convert(typeDesc).Interface(), nil
 	case reflect.Ptr:
diff --git a/common/types/uint_test.go b/common/types/uint_test.go
index 777d7955..f07832ca 100644
--- a/common/types/uint_test.go
+++ b/common/types/uint_test.go
@@ -172,6 +172,38 @@ func TestUintConvertToNative_Json(t *testing.T) {
 	}
 }
 
+func TestUintConvertToNative_Uint8(t *testing.T) {
+	val, err := Uint(128).ConvertToNative(reflect.TypeOf(uint8(0)))
+	if err != nil {
+		t.Fatalf("Uint.ConvertToNative(uint8) failed: %v", err)
+	}
+	if val.(uint8) != 128 {
+		t.Errorf("Got '%v', expected 128", val)
+	}
+	val, err = Uint(math.MaxUint8 + 1).ConvertToNative(reflect.TypeOf(uint8(0)))
+	if err == nil {
+		t.Errorf("(MaxUint+1).ConvertToNative(uint8) did not error, got: %v", val)
+	} else if !strings.Contains(err.Error(), "unsigned integer overflow") {
+		t.Errorf("ConvertToNative(uint8) returned unexpected error: %v, wanted unsigned integer overflow", err)
+	}
+}
+
+func TestUintConvertToNative_Uint16(t *testing.T) {
+	val, err := Uint(20050).ConvertToNative(reflect.TypeOf(uint16(0)))
+	if err != nil {
+		t.Fatalf("Uint.ConvertToNative(uint16) failed: %v", err)
+	}
+	if val.(uint16) != 20050 {
+		t.Errorf("Got '%v', expected 20050", val)
+	}
+	val, err = Uint(math.MaxUint16 + 1).ConvertToNative(reflect.TypeOf(uint16(0)))
+	if err == nil {
+		t.Errorf("(MaxUint+1).ConvertToNative(uint16) did not error, got: %v", val)
+	} else if !strings.Contains(err.Error(), "unsigned integer overflow") {
+		t.Errorf("ConvertToNative(uint16) returned unexpected error: %v, wanted unsigned integer overflow", err)
+	}
+}
+
 func TestUintConvertToNative_Uint32(t *testing.T) {
 	val, err := Uint(20050).ConvertToNative(reflect.TypeOf(uint32(0)))
 	if err != nil {