From cf970d7b6085939aa4ad540cabc8307cfed5aa0a Mon Sep 17 00:00:00 2001
From: Matt Keeler <mjkeeler7@gmail.com>
Date: Mon, 21 Aug 2023 16:29:34 -0400
Subject: [PATCH] Reduce required type arguments for DecodedResource

---
 internal/catalog/exports.go                   |  2 +-
 .../controllers/failover/controller.go        | 20 +++++++--------
 .../mappers/failovermapper/failover_mapper.go |  2 +-
 .../failovermapper/failover_mapper_test.go    |  6 ++---
 .../internal/types/failover_policy_test.go    | 12 ++++-----
 internal/resource/decode.go                   | 25 +++++++------------
 internal/resource/decode_test.go              |  8 +++---
 internal/resource/resourcetest/decode.go      |  7 ++----
 8 files changed, 36 insertions(+), 46 deletions(-)

diff --git a/internal/catalog/exports.go b/internal/catalog/exports.go
index 4019fbeb51ac..566c2e2b6edf 100644
--- a/internal/catalog/exports.go
+++ b/internal/catalog/exports.go
@@ -118,7 +118,7 @@ func SimplifyFailoverPolicy(svc *pbcatalog.Service, failover *pbcatalog.Failover
 // FailoverPolicyMapper maintains the bidirectional tracking relationship of a
 // FailoverPolicy to the Services related to it.
 type FailoverPolicyMapper interface {
-	TrackFailover(failover *resource.DecodedResource[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy])
+	TrackFailover(failover *resource.DecodedResource[*pbcatalog.FailoverPolicy])
 	UntrackFailover(failoverID *pbresource.ID)
 	FailoverIDsByService(svcID *pbresource.ID) []*pbresource.ID
 }
diff --git a/internal/catalog/internal/controllers/failover/controller.go b/internal/catalog/internal/controllers/failover/controller.go
index ea6efa992d9f..9accb62aa447 100644
--- a/internal/catalog/internal/controllers/failover/controller.go
+++ b/internal/catalog/internal/controllers/failover/controller.go
@@ -20,7 +20,7 @@ type FailoverMapper interface {
 	// TrackFailover extracts all Service references from the provided
 	// FailoverPolicy and indexes them so that MapService can turn Service
 	// events into FailoverPolicy events properly.
-	TrackFailover(failover *resource.DecodedResource[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy])
+	TrackFailover(failover *resource.DecodedResource[*pbcatalog.FailoverPolicy])
 
 	// UntrackFailover forgets the links inserted by TrackFailover for the
 	// provided FailoverPolicyID.
@@ -86,7 +86,7 @@ func (r *failoverPolicyReconciler) Reconcile(ctx context.Context, rt controller.
 		rt.Logger.Error("error retrieving corresponding service", "error", err)
 		return err
 	}
-	destServices := make(map[resource.ReferenceKey]*resource.DecodedResource[pbcatalog.Service, *pbcatalog.Service])
+	destServices := make(map[resource.ReferenceKey]*resource.DecodedResource[*pbcatalog.Service])
 	if service != nil {
 		destServices[resource.NewReferenceKey(serviceID)] = service
 	}
@@ -148,18 +148,18 @@ func (r *failoverPolicyReconciler) Reconcile(ctx context.Context, rt controller.
 	return nil
 }
 
-func getFailoverPolicy(ctx context.Context, rt controller.Runtime, id *pbresource.ID) (*resource.DecodedResource[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy], error) {
-	return resource.GetDecodedResource[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](ctx, rt.Client, id)
+func getFailoverPolicy(ctx context.Context, rt controller.Runtime, id *pbresource.ID) (*resource.DecodedResource[*pbcatalog.FailoverPolicy], error) {
+	return resource.GetDecodedResource[*pbcatalog.FailoverPolicy](ctx, rt.Client, id)
 }
 
-func getService(ctx context.Context, rt controller.Runtime, id *pbresource.ID) (*resource.DecodedResource[pbcatalog.Service, *pbcatalog.Service], error) {
-	return resource.GetDecodedResource[pbcatalog.Service, *pbcatalog.Service](ctx, rt.Client, id)
+func getService(ctx context.Context, rt controller.Runtime, id *pbresource.ID) (*resource.DecodedResource[*pbcatalog.Service], error) {
+	return resource.GetDecodedResource[*pbcatalog.Service](ctx, rt.Client, id)
 }
 
 func computeNewStatus(
-	failoverPolicy *resource.DecodedResource[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy],
-	service *resource.DecodedResource[pbcatalog.Service, *pbcatalog.Service],
-	destServices map[resource.ReferenceKey]*resource.DecodedResource[pbcatalog.Service, *pbcatalog.Service],
+	failoverPolicy *resource.DecodedResource[*pbcatalog.FailoverPolicy],
+	service *resource.DecodedResource[*pbcatalog.Service],
+	destServices map[resource.ReferenceKey]*resource.DecodedResource[*pbcatalog.Service],
 ) *pbresource.Status {
 	if service == nil {
 		return &pbresource.Status{
@@ -238,7 +238,7 @@ func computeNewStatus(
 
 func serviceHasPort(
 	dest *pbcatalog.FailoverDestination,
-	destServices map[resource.ReferenceKey]*resource.DecodedResource[pbcatalog.Service, *pbcatalog.Service],
+	destServices map[resource.ReferenceKey]*resource.DecodedResource[*pbcatalog.Service],
 ) *pbresource.Condition {
 	key := resource.NewReferenceKey(dest.Ref)
 	destService, ok := destServices[key]
diff --git a/internal/catalog/internal/mappers/failovermapper/failover_mapper.go b/internal/catalog/internal/mappers/failovermapper/failover_mapper.go
index 5c23a1bfe3a1..4ae6776cb66c 100644
--- a/internal/catalog/internal/mappers/failovermapper/failover_mapper.go
+++ b/internal/catalog/internal/mappers/failovermapper/failover_mapper.go
@@ -31,7 +31,7 @@ func New() *Mapper {
 // TrackFailover extracts all Service references from the provided
 // FailoverPolicy and indexes them so that MapService can turn Service events
 // into FailoverPolicy events properly.
-func (m *Mapper) TrackFailover(failover *resource.DecodedResource[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy]) {
+func (m *Mapper) TrackFailover(failover *resource.DecodedResource[*pbcatalog.FailoverPolicy]) {
 	destRefs := failover.Data.GetUnderlyingDestinationRefs()
 	destRefs = append(destRefs, &pbresource.Reference{
 		Type:    types.ServiceType,
diff --git a/internal/catalog/internal/mappers/failovermapper/failover_mapper_test.go b/internal/catalog/internal/mappers/failovermapper/failover_mapper_test.go
index 8a4ac2d7227d..048f444eca61 100644
--- a/internal/catalog/internal/mappers/failovermapper/failover_mapper_test.go
+++ b/internal/catalog/internal/mappers/failovermapper/failover_mapper_test.go
@@ -59,7 +59,7 @@ func TestMapper_Tracking(t *testing.T) {
 		}).
 		Build()
 	rtest.ValidateAndNormalize(t, registry, fail1)
-	failDec1 := rtest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, fail1)
+	failDec1 := rtest.MustDecode[*pbcatalog.FailoverPolicy](t, fail1)
 
 	fail2 := rtest.Resource(types.FailoverPolicyType, "www").
 		WithData(t, &pbcatalog.FailoverPolicy{
@@ -72,7 +72,7 @@ func TestMapper_Tracking(t *testing.T) {
 		}).
 		Build()
 	rtest.ValidateAndNormalize(t, registry, fail2)
-	failDec2 := rtest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, fail2)
+	failDec2 := rtest.MustDecode[*pbcatalog.FailoverPolicy](t, fail2)
 
 	fail1_updated := rtest.Resource(types.FailoverPolicyType, "api").
 		WithData(t, &pbcatalog.FailoverPolicy{
@@ -84,7 +84,7 @@ func TestMapper_Tracking(t *testing.T) {
 		}).
 		Build()
 	rtest.ValidateAndNormalize(t, registry, fail1_updated)
-	failDec1_updated := rtest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, fail1_updated)
+	failDec1_updated := rtest.MustDecode[*pbcatalog.FailoverPolicy](t, fail1_updated)
 
 	m := New()
 
diff --git a/internal/catalog/internal/types/failover_policy_test.go b/internal/catalog/internal/types/failover_policy_test.go
index 8f2ad9717298..41bfd3d82770 100644
--- a/internal/catalog/internal/types/failover_policy_test.go
+++ b/internal/catalog/internal/types/failover_policy_test.go
@@ -31,7 +31,7 @@ func TestMutateFailoverPolicy(t *testing.T) {
 
 		err := MutateFailoverPolicy(res)
 
-		got := resourcetest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, res)
+		got := resourcetest.MustDecode[*pbcatalog.FailoverPolicy](t, res)
 
 		if tc.expectErr == "" {
 			require.NoError(t, err)
@@ -162,13 +162,13 @@ func TestValidateFailoverPolicy(t *testing.T) {
 		require.NoError(t, MutateFailoverPolicy(res))
 
 		// Verify that mutate didn't actually change the object.
-		got := resourcetest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, res)
+		got := resourcetest.MustDecode[*pbcatalog.FailoverPolicy](t, res)
 		prototest.AssertDeepEqual(t, tc.failover, got.Data)
 
 		err := ValidateFailoverPolicy(res)
 
 		// Verify that validate didn't actually change the object.
-		got = resourcetest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, res)
+		got = resourcetest.MustDecode[*pbcatalog.FailoverPolicy](t, res)
 		prototest.AssertDeepEqual(t, tc.failover, got.Data)
 
 		if tc.expectErr == "" {
@@ -359,9 +359,9 @@ func TestSimplifyFailoverPolicy(t *testing.T) {
 		resourcetest.ValidateAndNormalize(t, registry, tc.failover)
 		resourcetest.ValidateAndNormalize(t, registry, tc.expect)
 
-		svc := resourcetest.MustDecode[pbcatalog.Service, *pbcatalog.Service](t, tc.svc)
-		failover := resourcetest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, tc.failover)
-		expect := resourcetest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, tc.expect)
+		svc := resourcetest.MustDecode[*pbcatalog.Service](t, tc.svc)
+		failover := resourcetest.MustDecode[*pbcatalog.FailoverPolicy](t, tc.failover)
+		expect := resourcetest.MustDecode[*pbcatalog.FailoverPolicy](t, tc.expect)
 
 		inputFailoverCopy := proto.Clone(failover.Data).(*pbcatalog.FailoverPolicy)
 
diff --git a/internal/resource/decode.go b/internal/resource/decode.go
index 7b1fb7b36450..c610898ca3c2 100644
--- a/internal/resource/decode.go
+++ b/internal/resource/decode.go
@@ -15,27 +15,23 @@ import (
 
 // DecodedResource is a generic holder to contain an original Resource and its
 // decoded contents.
-type DecodedResource[V any, PV interface {
-	proto.Message
-	*V
-}] struct {
+type DecodedResource[T proto.Message] struct {
 	Resource *pbresource.Resource
-	Data     PV
+	Data     T
 }
 
 // Decode will generically decode the provided resource into a 2-field
 // structure that holds onto the original Resource and the decoded contents.
 //
 // Returns an ErrDataParse on unmarshalling errors.
-func Decode[V any, PV interface {
-	proto.Message
-	*V
-}](res *pbresource.Resource) (*DecodedResource[V, PV], error) {
-	data := PV(new(V))
+func Decode[T proto.Message](res *pbresource.Resource) (*DecodedResource[T], error) {
+	var zero T
+	data := zero.ProtoReflect().New().Interface().(T)
+
 	if err := res.Data.UnmarshalTo(data); err != nil {
 		return nil, NewErrDataParse(data, err)
 	}
-	return &DecodedResource[V, PV]{
+	return &DecodedResource[T]{
 		Resource: res,
 		Data:     data,
 	}, nil
@@ -43,10 +39,7 @@ func Decode[V any, PV interface {
 
 // GetDecodedResource will generically read the requested resource using the
 // client and either return nil on a NotFound or decode the response value.
-func GetDecodedResource[V any, PV interface {
-	proto.Message
-	*V
-}](ctx context.Context, client pbresource.ResourceServiceClient, id *pbresource.ID) (*DecodedResource[V, PV], error) {
+func GetDecodedResource[T proto.Message](ctx context.Context, client pbresource.ResourceServiceClient, id *pbresource.ID) (*DecodedResource[T], error) {
 	rsp, err := client.Read(ctx, &pbresource.ReadRequest{Id: id})
 	switch {
 	case status.Code(err) == codes.NotFound:
@@ -55,5 +48,5 @@ func GetDecodedResource[V any, PV interface {
 		return nil, err
 	}
 
-	return Decode[V, PV](rsp.Resource)
+	return Decode[T](rsp.Resource)
 }
diff --git a/internal/resource/decode_test.go b/internal/resource/decode_test.go
index 31ebe47c64bf..17c1bd7f1b07 100644
--- a/internal/resource/decode_test.go
+++ b/internal/resource/decode_test.go
@@ -34,7 +34,7 @@ func TestGetDecodedResource(t *testing.T) {
 	}
 
 	testutil.RunStep(t, "not found", func(t *testing.T) {
-		got, err := resource.GetDecodedResource[pbdemo.Artist, *pbdemo.Artist](ctx, client, babypantsID)
+		got, err := resource.GetDecodedResource[*pbdemo.Artist](ctx, client, babypantsID)
 		require.NoError(t, err)
 		require.Nil(t, got)
 	})
@@ -47,7 +47,7 @@ func TestGetDecodedResource(t *testing.T) {
 			WithData(t, data).
 			Write(t, client)
 
-		got, err := resource.GetDecodedResource[pbdemo.Artist, *pbdemo.Artist](ctx, client, babypantsID)
+		got, err := resource.GetDecodedResource[*pbdemo.Artist](ctx, client, babypantsID)
 		require.NoError(t, err)
 		require.NotNil(t, got)
 
@@ -84,7 +84,7 @@ func TestDecode(t *testing.T) {
 			},
 		}
 
-		dec, err := resource.Decode[pbdemo.Artist, *pbdemo.Artist](foo)
+		dec, err := resource.Decode[*pbdemo.Artist](foo)
 		require.NoError(t, err)
 
 		prototest.AssertDeepEqual(t, foo, dec.Resource)
@@ -107,7 +107,7 @@ func TestDecode(t *testing.T) {
 			},
 		}
 
-		_, err := resource.Decode[pbdemo.Artist, *pbdemo.Artist](foo)
+		_, err := resource.Decode[*pbdemo.Artist](foo)
 		require.Error(t, err)
 	})
 }
diff --git a/internal/resource/resourcetest/decode.go b/internal/resource/resourcetest/decode.go
index 077bbc0dd514..d68fff865517 100644
--- a/internal/resource/resourcetest/decode.go
+++ b/internal/resource/resourcetest/decode.go
@@ -13,11 +13,8 @@ import (
 	"github.com/hashicorp/consul/proto-public/pbresource"
 )
 
-func MustDecode[V any, PV interface {
-	proto.Message
-	*V
-}](t *testing.T, res *pbresource.Resource) *resource.DecodedResource[V, PV] {
-	dec, err := resource.Decode[V, PV](res)
+func MustDecode[T proto.Message](t *testing.T, res *pbresource.Resource) *resource.DecodedResource[T] {
+	dec, err := resource.Decode[T](res)
 	require.NoError(t, err)
 	return dec
 }