From bbfa21c2bfa311162f388c3468375cbaf32bbfa1 Mon Sep 17 00:00:00 2001 From: Joel Hendrix Date: Thu, 3 Feb 2022 13:40:30 -0800 Subject: [PATCH] Convert to package, JSON null funcs to generics NullValue now takes a generic type parameter instead of an interface arg to determine the type of null sentinel value to create. IsNullValue infers its generic parameter to determine the type of null sentinel value to look for. At present, there is no way to express a 'nillable' generic type constraint so the funcs simply take a [T any] which should be fine as they typically take/return pointer-to-types. The 'to' package has been reduced to two funcs. --- sdk/azcore/core.go | 22 +++-- sdk/azcore/core_test.go | 42 +++------ sdk/azcore/example_test.go | 2 +- sdk/azcore/to/to.go | 104 ++--------------------- sdk/azcore/to/to_test.go | 169 ++----------------------------------- 5 files changed, 36 insertions(+), 303 deletions(-) diff --git a/sdk/azcore/core.go b/sdk/azcore/core.go index b8bf33df5b33..ec313526c0ba 100644 --- a/sdk/azcore/core.go +++ b/sdk/azcore/core.go @@ -23,14 +23,16 @@ type TokenCredential = shared.TokenCredential // holds sentinel values used to send nulls var nullables map[reflect.Type]interface{} = map[reflect.Type]interface{}{} +func typeOfT[T any]() reflect.Type { + // you can't, at present, obtain the type of + // a type parameter, so this is the trick + return reflect.TypeOf((*T)(nil)).Elem() +} + // NullValue is used to send an explicit 'null' within a request. // This is typically used in JSON-MERGE-PATCH operations to delete a value. -func NullValue(v interface{}) interface{} { - t := reflect.TypeOf(v) - if k := t.Kind(); k != reflect.Ptr && k != reflect.Slice && k != reflect.Map { - // t is not of pointer type, make it be of pointer type - t = reflect.PtrTo(t) - } +func NullValue[T any]() T { + t := typeOfT[T]() v, found := nullables[t] if !found { var o reflect.Value @@ -48,18 +50,14 @@ func NullValue(v interface{}) interface{} { nullables[t] = v } // return the sentinel object - return v + return v.(T) } // IsNullValue returns true if the field contains a null sentinel value. // This is used by custom marshallers to properly encode a null value. -func IsNullValue(v interface{}) bool { +func IsNullValue[T any](v T) bool { // see if our map has a sentinel object for this *T t := reflect.TypeOf(v) - if k := t.Kind(); k != reflect.Ptr && k != reflect.Slice && k != reflect.Map { - // v isn't a pointer type so it can never be a null - return false - } if o, found := nullables[t]; found { o1 := reflect.ValueOf(o) v1 := reflect.ValueOf(v) diff --git a/sdk/azcore/core_test.go b/sdk/azcore/core_test.go index 2cbd1d893de2..56b8b5504542 100644 --- a/sdk/azcore/core_test.go +++ b/sdk/azcore/core_test.go @@ -12,24 +12,11 @@ import ( ) func TestNullValue(t *testing.T) { - v := NullValue("") - if _, ok := v.(*string); !ok { - t.Fatalf("unexpected type %T", v) - } - vv := NullValue((*string)(nil)) - if _, ok := vv.(*string); !ok { - t.Fatalf("unexpected type %T", vv) - } + v := NullValue[*string]() + vv := NullValue[*string]() if v != vv { t.Fatal("null values should match for the same types") } - i := NullValue(1) - if _, ok := i.(*int); !ok { - t.Fatalf("unexpected type %T", v) - } - if v == i { - t.Fatal("null values for string and int should not match") - } } func TestIsNullValue(t *testing.T) { @@ -44,7 +31,7 @@ func TestIsNullValue(t *testing.T) { if IsNullValue(i) { t.Fatal("i isn't a null value") } - i = NullValue(0).(*int) + i = NullValue[*int]() if !IsNullValue(i) { t.Fatal("expected null value for i") } @@ -56,21 +43,12 @@ func TestIsNullValue(t *testing.T) { } func TestNullValueMapSlice(t *testing.T) { - v := NullValue([]string{}) - if _, ok := v.([]string); !ok { - t.Fatalf("unexpected type %T", v) - } - vv := NullValue(([]string)(nil)) - if _, ok := vv.([]string); !ok { - t.Fatalf("unexpected type %T", vv) - } + v := NullValue[[]string]() + vv := NullValue[[]string]() if reflect.TypeOf(v) != reflect.TypeOf(vv) { t.Fatal("null values should match for the same types") } - m := NullValue(map[string]int{}) - if _, ok := m.(map[string]int); !ok { - t.Fatalf("unexpected type %T", m) - } + m := NullValue[map[string]int]() if reflect.TypeOf(v) == reflect.TypeOf(m) { t.Fatal("null values for string and int should not match") } @@ -83,11 +61,11 @@ func TestIsNullValueMapSlice(t *testing.T) { if IsNullValue(map[int]string{}) { t.Fatal("map literal can't be a null value") } - s := NullValue([]int{}).([]int) + s := NullValue[[]int]() if !IsNullValue(s) { t.Fatal("expected null value for s") } - m := NullValue(map[string]interface{}{}).(map[string]interface{}) + m := NullValue[map[string]interface{}]() if !IsNullValue(m) { t.Fatal("expected null value for s") } @@ -114,8 +92,8 @@ func TestIsNullValueMapSlice(t *testing.T) { t.Fatal("unexpected null slice") } - nf.Map = NullValue(map[string]int{}).(map[string]int) - nf.Slice = NullValue([]string{}).([]string) + nf.Map = NullValue[map[string]int]() + nf.Slice = NullValue[[]string]() if !IsNullValue(nf.Map) { t.Fatal("expected null map") } diff --git a/sdk/azcore/example_test.go b/sdk/azcore/example_test.go index 434051bcefe9..955f7777fd40 100644 --- a/sdk/azcore/example_test.go +++ b/sdk/azcore/example_test.go @@ -56,7 +56,7 @@ func (w Widget) MarshalJSON() ([]byte, error) { func ExampleNullValue() { w := Widget{ - Count: azcore.NullValue(0).(*int), + Count: azcore.NullValue[*int](), } b, _ := json.Marshal(w) fmt.Println(string(b)) diff --git a/sdk/azcore/to/to.go b/sdk/azcore/to/to.go index 57a8d10ecc3f..e0e4817b90d1 100644 --- a/sdk/azcore/to/to.go +++ b/sdk/azcore/to/to.go @@ -6,102 +6,16 @@ package to -import "time" - -// BoolPtr returns a pointer to the provided bool. -func BoolPtr(b bool) *bool { - return &b -} - -// Float32Ptr returns a pointer to the provided float32. -func Float32Ptr(i float32) *float32 { - return &i -} - -// Float64Ptr returns a pointer to the provided float64. -func Float64Ptr(i float64) *float64 { - return &i -} - -// Int32Ptr returns a pointer to the provided int32. -func Int32Ptr(i int32) *int32 { - return &i -} - -// Int64Ptr returns a pointer to the provided int64. -func Int64Ptr(i int64) *int64 { - return &i -} - -// StringPtr returns a pointer to the provided string. -func StringPtr(s string) *string { - return &s -} - -// TimePtr returns a pointer to the provided time.Time. -func TimePtr(t time.Time) *time.Time { - return &t -} - -// Int32PtrArray returns an array of *int32 from the specified values. -func Int32PtrArray(vals ...int32) []*int32 { - arr := make([]*int32, len(vals)) - for i := range vals { - arr[i] = Int32Ptr(vals[i]) - } - return arr -} - -// Int64PtrArray returns an array of *int64 from the specified values. -func Int64PtrArray(vals ...int64) []*int64 { - arr := make([]*int64, len(vals)) - for i := range vals { - arr[i] = Int64Ptr(vals[i]) - } - return arr -} - -// Float32PtrArray returns an array of *float32 from the specified values. -func Float32PtrArray(vals ...float32) []*float32 { - arr := make([]*float32, len(vals)) - for i := range vals { - arr[i] = Float32Ptr(vals[i]) - } - return arr -} - -// Float64PtrArray returns an array of *float64 from the specified values. -func Float64PtrArray(vals ...float64) []*float64 { - arr := make([]*float64, len(vals)) - for i := range vals { - arr[i] = Float64Ptr(vals[i]) - } - return arr -} - -// BoolPtrArray returns an array of *bool from the specified values. -func BoolPtrArray(vals ...bool) []*bool { - arr := make([]*bool, len(vals)) - for i := range vals { - arr[i] = BoolPtr(vals[i]) - } - return arr -} - -// StringPtrArray returns an array of *string from the specified values. -func StringPtrArray(vals ...string) []*string { - arr := make([]*string, len(vals)) - for i := range vals { - arr[i] = StringPtr(vals[i]) - } - return arr +// Ptr returns a pointer to the provided value. +func Ptr[T any](v T) *T { + return &v } -// TimePtrArray returns an array of *time.Time from the specified values. -func TimePtrArray(vals ...time.Time) []*time.Time { - arr := make([]*time.Time, len(vals)) - for i := range vals { - arr[i] = TimePtr(vals[i]) +// SliceOfPtrs returns a slice of *T from the specified values. +func SliceOfPtrs[T any](vv ...T) []*T { + slc := make([]*T, len(vv)) + for i := range vv { + slc[i] = Ptr(vv[i]) } - return arr + return slc } diff --git a/sdk/azcore/to/to_test.go b/sdk/azcore/to/to_test.go index 177e9a48a3e4..175f52a31aea 100644 --- a/sdk/azcore/to/to_test.go +++ b/sdk/azcore/to/to_test.go @@ -7,16 +7,12 @@ package to import ( - "fmt" - "reflect" - "strconv" "testing" - "time" ) -func TestBoolPtr(t *testing.T) { +func TestPtr(t *testing.T) { b := true - pb := BoolPtr(b) + pb := Ptr(b) if pb == nil { t.Fatal("unexpected nil conversion") } @@ -25,168 +21,15 @@ func TestBoolPtr(t *testing.T) { } } -func TestFloat32Ptr(t *testing.T) { - f32 := float32(3.1415926) - pf32 := Float32Ptr(f32) - if pf32 == nil { - t.Fatal("unexpected nil conversion") - } - if *pf32 != f32 { - t.Fatalf("got %v, want %v", *pf32, f32) - } -} - -func TestFloat64Ptr(t *testing.T) { - f64 := float64(2.71828182845904) - pf64 := Float64Ptr(f64) - if pf64 == nil { - t.Fatal("unexpected nil conversion") - } - if *pf64 != f64 { - t.Fatalf("got %v, want %v", *pf64, f64) - } -} - -func TestInt32Ptr(t *testing.T) { - i32 := int32(123456789) - pi32 := Int32Ptr(i32) - if pi32 == nil { - t.Fatal("unexpected nil conversion") - } - if *pi32 != i32 { - t.Fatalf("got %v, want %v", *pi32, i32) - } -} - -func TestInt64Ptr(t *testing.T) { - i64 := int64(9876543210) - pi64 := Int64Ptr(i64) - if pi64 == nil { - t.Fatal("unexpected nil conversion") - } - if *pi64 != i64 { - t.Fatalf("got %v, want %v", *pi64, i64) - } -} - -func TestStringPtr(t *testing.T) { - s := "the string" - ps := StringPtr(s) - if ps == nil { - t.Fatal("unexpected nil conversion") - } - if *ps != s { - t.Fatalf("got %v, want %v", *ps, s) - } -} - -func TestTimePtr(t *testing.T) { - tt := time.Now() - pt := TimePtr(tt) - if pt == nil { - t.Fatal("unexpected nil conversion") - } - if *pt != tt { - t.Fatalf("got %v, want %v", *pt, tt) - } -} - -func TestInt32PtrArray(t *testing.T) { - arr := Int32PtrArray() - if len(arr) != 0 { - t.Fatal("expected zero length") - } - arr = Int32PtrArray(1, 2, 3, 4, 5) - for i, v := range arr { - if *v != int32(i+1) { - t.Fatal("values don't match") - } - } -} - -func TestInt64PtrArray(t *testing.T) { - arr := Int64PtrArray() +func TestSliceOfPtrs(t *testing.T) { + arr := SliceOfPtrs[int]() if len(arr) != 0 { t.Fatal("expected zero length") } - arr = Int64PtrArray(1, 2, 3, 4, 5) + arr = SliceOfPtrs(1, 2, 3, 4, 5) for i, v := range arr { - if *v != int64(i+1) { + if *v != i+1 { t.Fatal("values don't match") } } } - -func TestFloat32PtrArray(t *testing.T) { - arr := Float32PtrArray() - if len(arr) != 0 { - t.Fatal("expected zero length") - } - arr = Float32PtrArray(1.1, 2.2, 3.3, 4.4, 5.5) - for i, v := range arr { - f, err := strconv.ParseFloat(fmt.Sprintf("%d.%d", i+1, i+1), 32) - if err != nil { - t.Fatal(err) - } - if *v != float32(f) { - t.Fatal("values don't match") - } - } -} - -func TestFloat64PtrArray(t *testing.T) { - arr := Float64PtrArray() - if len(arr) != 0 { - t.Fatal("expected zero length") - } - arr = Float64PtrArray(1.1, 2.2, 3.3, 4.4, 5.5) - for i, v := range arr { - f, err := strconv.ParseFloat(fmt.Sprintf("%d.%d", i+1, i+1), 64) - if err != nil { - t.Fatal(err) - } - if *v != f { - t.Fatal("values don't match") - } - } -} - -func TestBoolPtrArray(t *testing.T) { - arr := BoolPtrArray() - if len(arr) != 0 { - t.Fatal("expected zero length") - } - arr = BoolPtrArray(true, false, true) - curr := true - for _, v := range arr { - if *v != curr { - t.Fatal("values don'p match") - } - curr = !curr - } -} - -func TestStringPtrArray(t *testing.T) { - arr := StringPtrArray() - if len(arr) != 0 { - t.Fatal("expected zero length") - } - arr = StringPtrArray("one", "", "three") - if !reflect.DeepEqual(arr, []*string{StringPtr("one"), StringPtr(""), StringPtr("three")}) { - t.Fatal("values don't match") - } -} - -func TestTimePtrArray(t *testing.T) { - arr := TimePtrArray() - if len(arr) != 0 { - t.Fatal("expected zero length") - } - t1 := time.Now() - t2 := time.Time{} - t3 := t1.Add(24 * time.Hour) - arr = TimePtrArray(t1, t2, t3) - if !reflect.DeepEqual(arr, []*time.Time{&t1, &t2, &t3}) { - t.Fatal("values don't match") - } -}