Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce required type arguments for DecodedResource #18540

Merged
merged 1 commit into from
Aug 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion internal/catalog/exports.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func SimplifyFailoverPolicy(svc *pbcatalog.Service, failover *pbcatalog.Failover
// FailoverPolicyMapper maintains the bidirectional tracking relationship of a
// FailoverPolicy to the Services related to it.
type FailoverPolicyMapper interface {
TrackFailover(failover *resource.DecodedResource[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy])
TrackFailover(failover *resource.DecodedResource[*pbcatalog.FailoverPolicy])
UntrackFailover(failoverID *pbresource.ID)
FailoverIDsByService(svcID *pbresource.ID) []*pbresource.ID
}
Expand Down
20 changes: 10 additions & 10 deletions internal/catalog/internal/controllers/failover/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ type FailoverMapper interface {
// TrackFailover extracts all Service references from the provided
// FailoverPolicy and indexes them so that MapService can turn Service
// events into FailoverPolicy events properly.
TrackFailover(failover *resource.DecodedResource[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy])
TrackFailover(failover *resource.DecodedResource[*pbcatalog.FailoverPolicy])

// UntrackFailover forgets the links inserted by TrackFailover for the
// provided FailoverPolicyID.
Expand Down Expand Up @@ -86,7 +86,7 @@ func (r *failoverPolicyReconciler) Reconcile(ctx context.Context, rt controller.
rt.Logger.Error("error retrieving corresponding service", "error", err)
return err
}
destServices := make(map[resource.ReferenceKey]*resource.DecodedResource[pbcatalog.Service, *pbcatalog.Service])
destServices := make(map[resource.ReferenceKey]*resource.DecodedResource[*pbcatalog.Service])
if service != nil {
destServices[resource.NewReferenceKey(serviceID)] = service
}
Expand Down Expand Up @@ -148,18 +148,18 @@ func (r *failoverPolicyReconciler) Reconcile(ctx context.Context, rt controller.
return nil
}

func getFailoverPolicy(ctx context.Context, rt controller.Runtime, id *pbresource.ID) (*resource.DecodedResource[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy], error) {
return resource.GetDecodedResource[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](ctx, rt.Client, id)
func getFailoverPolicy(ctx context.Context, rt controller.Runtime, id *pbresource.ID) (*resource.DecodedResource[*pbcatalog.FailoverPolicy], error) {
return resource.GetDecodedResource[*pbcatalog.FailoverPolicy](ctx, rt.Client, id)
}

func getService(ctx context.Context, rt controller.Runtime, id *pbresource.ID) (*resource.DecodedResource[pbcatalog.Service, *pbcatalog.Service], error) {
return resource.GetDecodedResource[pbcatalog.Service, *pbcatalog.Service](ctx, rt.Client, id)
func getService(ctx context.Context, rt controller.Runtime, id *pbresource.ID) (*resource.DecodedResource[*pbcatalog.Service], error) {
return resource.GetDecodedResource[*pbcatalog.Service](ctx, rt.Client, id)
}

func computeNewStatus(
failoverPolicy *resource.DecodedResource[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy],
service *resource.DecodedResource[pbcatalog.Service, *pbcatalog.Service],
destServices map[resource.ReferenceKey]*resource.DecodedResource[pbcatalog.Service, *pbcatalog.Service],
failoverPolicy *resource.DecodedResource[*pbcatalog.FailoverPolicy],
service *resource.DecodedResource[*pbcatalog.Service],
destServices map[resource.ReferenceKey]*resource.DecodedResource[*pbcatalog.Service],
) *pbresource.Status {
if service == nil {
return &pbresource.Status{
Expand Down Expand Up @@ -238,7 +238,7 @@ func computeNewStatus(

func serviceHasPort(
dest *pbcatalog.FailoverDestination,
destServices map[resource.ReferenceKey]*resource.DecodedResource[pbcatalog.Service, *pbcatalog.Service],
destServices map[resource.ReferenceKey]*resource.DecodedResource[*pbcatalog.Service],
) *pbresource.Condition {
key := resource.NewReferenceKey(dest.Ref)
destService, ok := destServices[key]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func New() *Mapper {
// TrackFailover extracts all Service references from the provided
// FailoverPolicy and indexes them so that MapService can turn Service events
// into FailoverPolicy events properly.
func (m *Mapper) TrackFailover(failover *resource.DecodedResource[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy]) {
func (m *Mapper) TrackFailover(failover *resource.DecodedResource[*pbcatalog.FailoverPolicy]) {
destRefs := failover.Data.GetUnderlyingDestinationRefs()
destRefs = append(destRefs, &pbresource.Reference{
Type: types.ServiceType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func TestMapper_Tracking(t *testing.T) {
}).
Build()
rtest.ValidateAndNormalize(t, registry, fail1)
failDec1 := rtest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, fail1)
failDec1 := rtest.MustDecode[*pbcatalog.FailoverPolicy](t, fail1)

fail2 := rtest.Resource(types.FailoverPolicyType, "www").
WithData(t, &pbcatalog.FailoverPolicy{
Expand All @@ -72,7 +72,7 @@ func TestMapper_Tracking(t *testing.T) {
}).
Build()
rtest.ValidateAndNormalize(t, registry, fail2)
failDec2 := rtest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, fail2)
failDec2 := rtest.MustDecode[*pbcatalog.FailoverPolicy](t, fail2)

fail1_updated := rtest.Resource(types.FailoverPolicyType, "api").
WithData(t, &pbcatalog.FailoverPolicy{
Expand All @@ -84,7 +84,7 @@ func TestMapper_Tracking(t *testing.T) {
}).
Build()
rtest.ValidateAndNormalize(t, registry, fail1_updated)
failDec1_updated := rtest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, fail1_updated)
failDec1_updated := rtest.MustDecode[*pbcatalog.FailoverPolicy](t, fail1_updated)

m := New()

Expand Down
12 changes: 6 additions & 6 deletions internal/catalog/internal/types/failover_policy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func TestMutateFailoverPolicy(t *testing.T) {

err := MutateFailoverPolicy(res)

got := resourcetest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, res)
got := resourcetest.MustDecode[*pbcatalog.FailoverPolicy](t, res)

if tc.expectErr == "" {
require.NoError(t, err)
Expand Down Expand Up @@ -162,13 +162,13 @@ func TestValidateFailoverPolicy(t *testing.T) {
require.NoError(t, MutateFailoverPolicy(res))

// Verify that mutate didn't actually change the object.
got := resourcetest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, res)
got := resourcetest.MustDecode[*pbcatalog.FailoverPolicy](t, res)
prototest.AssertDeepEqual(t, tc.failover, got.Data)

err := ValidateFailoverPolicy(res)

// Verify that validate didn't actually change the object.
got = resourcetest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, res)
got = resourcetest.MustDecode[*pbcatalog.FailoverPolicy](t, res)
prototest.AssertDeepEqual(t, tc.failover, got.Data)

if tc.expectErr == "" {
Expand Down Expand Up @@ -359,9 +359,9 @@ func TestSimplifyFailoverPolicy(t *testing.T) {
resourcetest.ValidateAndNormalize(t, registry, tc.failover)
resourcetest.ValidateAndNormalize(t, registry, tc.expect)

svc := resourcetest.MustDecode[pbcatalog.Service, *pbcatalog.Service](t, tc.svc)
failover := resourcetest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, tc.failover)
expect := resourcetest.MustDecode[pbcatalog.FailoverPolicy, *pbcatalog.FailoverPolicy](t, tc.expect)
svc := resourcetest.MustDecode[*pbcatalog.Service](t, tc.svc)
failover := resourcetest.MustDecode[*pbcatalog.FailoverPolicy](t, tc.failover)
expect := resourcetest.MustDecode[*pbcatalog.FailoverPolicy](t, tc.expect)

inputFailoverCopy := proto.Clone(failover.Data).(*pbcatalog.FailoverPolicy)

Expand Down
25 changes: 9 additions & 16 deletions internal/resource/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,38 +15,31 @@ import (

// DecodedResource is a generic holder to contain an original Resource and its
// decoded contents.
type DecodedResource[V any, PV interface {
proto.Message
*V
}] struct {
type DecodedResource[T proto.Message] struct {
Resource *pbresource.Resource
Data PV
Data T
}

// Decode will generically decode the provided resource into a 2-field
// structure that holds onto the original Resource and the decoded contents.
//
// Returns an ErrDataParse on unmarshalling errors.
func Decode[V any, PV interface {
proto.Message
*V
}](res *pbresource.Resource) (*DecodedResource[V, PV], error) {
data := PV(new(V))
func Decode[T proto.Message](res *pbresource.Resource) (*DecodedResource[T], error) {
var zero T
data := zero.ProtoReflect().New().Interface().(T)

if err := res.Data.UnmarshalTo(data); err != nil {
return nil, NewErrDataParse(data, err)
}
return &DecodedResource[V, PV]{
return &DecodedResource[T]{
Resource: res,
Data: data,
}, nil
}

// GetDecodedResource will generically read the requested resource using the
// client and either return nil on a NotFound or decode the response value.
func GetDecodedResource[V any, PV interface {
proto.Message
*V
}](ctx context.Context, client pbresource.ResourceServiceClient, id *pbresource.ID) (*DecodedResource[V, PV], error) {
func GetDecodedResource[T proto.Message](ctx context.Context, client pbresource.ResourceServiceClient, id *pbresource.ID) (*DecodedResource[T], error) {
rsp, err := client.Read(ctx, &pbresource.ReadRequest{Id: id})
switch {
case status.Code(err) == codes.NotFound:
Expand All @@ -55,5 +48,5 @@ func GetDecodedResource[V any, PV interface {
return nil, err
}

return Decode[V, PV](rsp.Resource)
return Decode[T](rsp.Resource)
}
8 changes: 4 additions & 4 deletions internal/resource/decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func TestGetDecodedResource(t *testing.T) {
}

testutil.RunStep(t, "not found", func(t *testing.T) {
got, err := resource.GetDecodedResource[pbdemo.Artist, *pbdemo.Artist](ctx, client, babypantsID)
got, err := resource.GetDecodedResource[*pbdemo.Artist](ctx, client, babypantsID)
require.NoError(t, err)
require.Nil(t, got)
})
Expand All @@ -47,7 +47,7 @@ func TestGetDecodedResource(t *testing.T) {
WithData(t, data).
Write(t, client)

got, err := resource.GetDecodedResource[pbdemo.Artist, *pbdemo.Artist](ctx, client, babypantsID)
got, err := resource.GetDecodedResource[*pbdemo.Artist](ctx, client, babypantsID)
require.NoError(t, err)
require.NotNil(t, got)

Expand Down Expand Up @@ -84,7 +84,7 @@ func TestDecode(t *testing.T) {
},
}

dec, err := resource.Decode[pbdemo.Artist, *pbdemo.Artist](foo)
dec, err := resource.Decode[*pbdemo.Artist](foo)
require.NoError(t, err)

prototest.AssertDeepEqual(t, foo, dec.Resource)
Expand All @@ -107,7 +107,7 @@ func TestDecode(t *testing.T) {
},
}

_, err := resource.Decode[pbdemo.Artist, *pbdemo.Artist](foo)
_, err := resource.Decode[*pbdemo.Artist](foo)
require.Error(t, err)
})
}
7 changes: 2 additions & 5 deletions internal/resource/resourcetest/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,8 @@ import (
"github.com/hashicorp/consul/proto-public/pbresource"
)

func MustDecode[V any, PV interface {
proto.Message
*V
}](t *testing.T, res *pbresource.Resource) *resource.DecodedResource[V, PV] {
dec, err := resource.Decode[V, PV](res)
func MustDecode[T proto.Message](t *testing.T, res *pbresource.Resource) *resource.DecodedResource[T] {
dec, err := resource.Decode[T](res)
require.NoError(t, err)
return dec
}