diff --git a/cmd/aws-application-networking-k8s/main.go b/cmd/aws-application-networking-k8s/main.go index 35d75648..21441474 100644 --- a/cmd/aws-application-networking-k8s/main.go +++ b/cmd/aws-application-networking-k8s/main.go @@ -78,12 +78,10 @@ func addOptionalCRDs(scheme *runtime.Scheme) { Version: "v1alpha1", } scheme.AddKnownTypes(awsGatewayControllerCRDGroupVersion, &anv1alpha1.TargetGroupPolicy{}, &anv1alpha1.TargetGroupPolicyList{}) - metav1.AddToGroupVersion(scheme, awsGatewayControllerCRDGroupVersion) - scheme.AddKnownTypes(awsGatewayControllerCRDGroupVersion, &anv1alpha1.VpcAssociationPolicy{}, &anv1alpha1.VpcAssociationPolicyList{}) - metav1.AddToGroupVersion(scheme, awsGatewayControllerCRDGroupVersion) - scheme.AddKnownTypes(awsGatewayControllerCRDGroupVersion, &anv1alpha1.AccessLogPolicy{}, &anv1alpha1.AccessLogPolicyList{}) + scheme.AddKnownTypes(awsGatewayControllerCRDGroupVersion, &anv1alpha1.IAMAuthPolicy{}, &anv1alpha1.IAMAuthPolicyList{}) + metav1.AddToGroupVersion(scheme, awsGatewayControllerCRDGroupVersion) } @@ -186,6 +184,11 @@ func main() { setupLog.Fatalf("accesslogpolicy controller setup failed: %s", err) } + err = controllers.RegisterIAMAuthPolicyController(ctrlLog.Named("iam-auth-policy"), cloud, latticeDataStore, finalizerManager, mgr) + if err != nil { + setupLog.Fatalf("iamauthpolicy controller setup failed: %s", err) + } + go latticestore.GetDefaultLatticeDataStore().ServeIntrospection() //+kubebuilder:scaffold:builder diff --git a/controllers/accesslogpolicy_controller.go b/controllers/accesslogpolicy_controller.go index f174c947..ab674a42 100644 --- a/controllers/accesslogpolicy_controller.go +++ b/controllers/accesslogpolicy_controller.go @@ -31,9 +31,13 @@ import ( pkg_builder "sigs.k8s.io/controller-runtime/pkg/builder" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/predicate" + "sigs.k8s.io/controller-runtime/pkg/source" gwv1alpha2 "sigs.k8s.io/gateway-api/apis/v1alpha2" gwv1beta1 "sigs.k8s.io/gateway-api/apis/v1beta1" + gwvv1beta1 "sigs.k8s.io/gateway-api/apis/v1beta1" + + "github.com/aws/aws-application-networking-k8s/controllers/eventhandlers" anv1alpha1 "github.com/aws/aws-application-networking-k8s/pkg/apis/applicationnetworking/v1alpha1" "github.com/aws/aws-application-networking-k8s/pkg/aws" "github.com/aws/aws-application-networking-k8s/pkg/aws/services" @@ -88,8 +92,15 @@ func RegisterAccessLogPolicyController( stackMarshaller: stackMarshaller, } + gatewayEventHandler := eventhandlers.NewGatewayEventHandler(log, mgrClient) + httpRouteEventHandler := eventhandlers.NewHTTPRouteEventHandler(log, mgrClient) + grpcRouteEventHandler := eventhandlers.NewGRPCRouteEventHandler(log, mgrClient) + builder := ctrl.NewControllerManagedBy(mgr). - For(&anv1alpha1.AccessLogPolicy{}, pkg_builder.WithPredicates(predicate.GenerationChangedPredicate{})) + For(&anv1alpha1.AccessLogPolicy{}, pkg_builder.WithPredicates(predicate.GenerationChangedPredicate{})). + Watches(&source.Kind{Type: &gwvv1beta1.Gateway{}}, gatewayEventHandler.MapToAccessLogPolicies(), pkg_builder.WithPredicates(predicate.GenerationChangedPredicate{})). + Watches(&source.Kind{Type: &gwvv1beta1.HTTPRoute{}}, httpRouteEventHandler.MapToAccessLogPolicies(), pkg_builder.WithPredicates(predicate.GenerationChangedPredicate{})). + Watches(&source.Kind{Type: &gwv1alpha2.GRPCRoute{}}, grpcRouteEventHandler.MapToAccessLogPolicies(), pkg_builder.WithPredicates(predicate.GenerationChangedPredicate{})) return builder.Complete(r) } diff --git a/controllers/eventhandlers/gateway.go b/controllers/eventhandlers/gateway.go index 7377c0b6..39612a00 100644 --- a/controllers/eventhandlers/gateway.go +++ b/controllers/eventhandlers/gateway.go @@ -4,6 +4,7 @@ import ( "context" "time" + "github.com/aws/aws-application-networking-k8s/pkg/k8s" "github.com/aws/aws-application-networking-k8s/pkg/model/core" "github.com/aws/aws-application-networking-k8s/pkg/utils/gwlog" @@ -21,6 +22,7 @@ import ( "github.com/aws/aws-application-networking-k8s/pkg/config" ) +// TODO: Remove `enqueueRequestsForGatewayEvent`, and use `gatewayEventHandler` only type enqueueRequestsForGatewayEvent struct { log gwlog.Logger client client.Client @@ -119,3 +121,42 @@ func (h *enqueueRequestsForGatewayEvent) enqueueImpactedRoutes(queue workqueue.R } } } + +type gatewayEventHandler struct { + log gwlog.Logger + client client.Client + mapper *resourceMapper +} + +func NewGatewayEventHandler(log gwlog.Logger, client client.Client) *gatewayEventHandler { + return &gatewayEventHandler{log: log, client: client, + mapper: &resourceMapper{log: log, client: client}} +} + +func (h *gatewayEventHandler) MapToIAMAuthPolicies() handler.EventHandler { + return handler.EnqueueRequestsFromMapFunc(func(obj client.Object) []reconcile.Request { + var requests []reconcile.Request + if gw, ok := obj.(*gateway_api.Gateway); ok { + policies := h.mapper.GatewayToIAMAuthPolicies(context.Background(), gw) + for _, p := range policies { + h.log.Infof("Gateway [%s/%s] resource change triggers IAMAuthPolicy [%s/%s] resource change", gw.Namespace, gw.Name, p.Namespace, p.Name) + requests = append(requests, reconcile.Request{NamespacedName: k8s.NamespacedName(p)}) + } + } + return requests + }) +} + +func (h *gatewayEventHandler) MapToAccessLogPolicies() handler.EventHandler { + return handler.EnqueueRequestsFromMapFunc(func(obj client.Object) []reconcile.Request { + var requests []reconcile.Request + if gw, ok := obj.(*gateway_api.Gateway); ok { + policies := h.mapper.GatewayToAccessLogPolicies(context.Background(), gw) + for _, p := range policies { + h.log.Infof("Gateway [%s/%s] resource change triggers AccessLogPolicy [%s/%s] resource change", gw.Namespace, gw.Name, p.Namespace, p.Name) + requests = append(requests, reconcile.Request{NamespacedName: k8s.NamespacedName(p)}) + } + } + return requests + }) +} diff --git a/controllers/eventhandlers/grpcroute.go b/controllers/eventhandlers/grpcroute.go new file mode 100644 index 00000000..f1bd9b1b --- /dev/null +++ b/controllers/eventhandlers/grpcroute.go @@ -0,0 +1,53 @@ +package eventhandlers + +import ( + "context" + + "sigs.k8s.io/gateway-api/apis/v1alpha2" + + "github.com/aws/aws-application-networking-k8s/pkg/k8s" + "github.com/aws/aws-application-networking-k8s/pkg/utils/gwlog" + + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/handler" + "sigs.k8s.io/controller-runtime/pkg/reconcile" +) + +type grpcRouteEventHandler struct { + log gwlog.Logger + client client.Client + mapper *resourceMapper +} + +func NewGRPCRouteEventHandler(log gwlog.Logger, client client.Client) *grpcRouteEventHandler { + return &grpcRouteEventHandler{log: log, client: client, + mapper: &resourceMapper{log: log, client: client}} +} + +func (h *grpcRouteEventHandler) MapToIAMAuthPolicies() handler.EventHandler { + return handler.EnqueueRequestsFromMapFunc(func(obj client.Object) []reconcile.Request { + var requests []reconcile.Request + if route, ok := obj.(*v1alpha2.GRPCRoute); ok { + policies := h.mapper.GRPCRouteToIAMAuthPolicies(context.Background(), route) + for _, p := range policies { + h.log.Infof("GRPCRoute [%s/%s] resource change triggers IAMAuthPolicy [%s/%s] resource change", route.Namespace, route.Name, p.Namespace, p.Name) + requests = append(requests, reconcile.Request{NamespacedName: k8s.NamespacedName(p)}) + } + } + return requests + }) +} + +func (h *grpcRouteEventHandler) MapToAccessLogPolicies() handler.EventHandler { + return handler.EnqueueRequestsFromMapFunc(func(obj client.Object) []reconcile.Request { + var requests []reconcile.Request + if route, ok := obj.(*v1alpha2.GRPCRoute); ok { + policies := h.mapper.GRPCRouteToAccessLogPolicies(context.Background(), route) + for _, p := range policies { + h.log.Infof("GRPCRoute [%s/%s] resource change triggers AccessLogPolicy [%s/%s] resource change", route.Namespace, route.Name, p.Namespace, p.Name) + requests = append(requests, reconcile.Request{NamespacedName: k8s.NamespacedName(p)}) + } + } + return requests + }) +} diff --git a/controllers/eventhandlers/httproute.go b/controllers/eventhandlers/httproute.go new file mode 100644 index 00000000..f96c70b2 --- /dev/null +++ b/controllers/eventhandlers/httproute.go @@ -0,0 +1,53 @@ +package eventhandlers + +import ( + "context" + + "sigs.k8s.io/gateway-api/apis/v1beta1" + + "github.com/aws/aws-application-networking-k8s/pkg/k8s" + "github.com/aws/aws-application-networking-k8s/pkg/utils/gwlog" + + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/handler" + "sigs.k8s.io/controller-runtime/pkg/reconcile" +) + +type httpRouteEventHandler struct { + log gwlog.Logger + client client.Client + mapper *resourceMapper +} + +func NewHTTPRouteEventHandler(log gwlog.Logger, client client.Client) *httpRouteEventHandler { + return &httpRouteEventHandler{log: log, client: client, + mapper: &resourceMapper{log: log, client: client}} +} + +func (h *httpRouteEventHandler) MapToIAMAuthPolicies() handler.EventHandler { + return handler.EnqueueRequestsFromMapFunc(func(obj client.Object) []reconcile.Request { + var requests []reconcile.Request + if route, ok := obj.(*v1beta1.HTTPRoute); ok { + policies := h.mapper.HTTPRouteToIAMAuthPolicies(context.Background(), route) + for _, p := range policies { + h.log.Infof("HTTPRoute [%s/%s] resource change triggers IAMAuthPolicy [%s/%s] resource change", route.Namespace, route.Name, p.Namespace, p.Name) + requests = append(requests, reconcile.Request{NamespacedName: k8s.NamespacedName(p)}) + } + } + return requests + }) +} + +func (h *httpRouteEventHandler) MapToAccessLogPolicies() handler.EventHandler { + return handler.EnqueueRequestsFromMapFunc(func(obj client.Object) []reconcile.Request { + var requests []reconcile.Request + if route, ok := obj.(*v1beta1.HTTPRoute); ok { + policies := h.mapper.HTTPRouteToAccessLogPolicies(context.Background(), route) + for _, p := range policies { + h.log.Infof("HTTPRoute [%s/%s] resource change triggers AccessLogPolicy [%s/%s] resource change", route.Namespace, route.Name, p.Namespace, p.Name) + requests = append(requests, reconcile.Request{NamespacedName: k8s.NamespacedName(p)}) + } + } + return requests + }) +} diff --git a/controllers/eventhandlers/mapper.go b/controllers/eventhandlers/mapper.go index 7ccf933e..de573ac8 100644 --- a/controllers/eventhandlers/mapper.go +++ b/controllers/eventhandlers/mapper.go @@ -12,6 +12,9 @@ import ( gateway_api "sigs.k8s.io/gateway-api/apis/v1beta1" mcs_api "sigs.k8s.io/mcs-api/pkg/apis/v1alpha1" + anv1alpha1 "github.com/aws/aws-application-networking-k8s/pkg/apis/applicationnetworking/v1alpha1" + "github.com/aws/aws-application-networking-k8s/pkg/gateway" + "github.com/aws/aws-application-networking-k8s/pkg/apis/applicationnetworking/v1alpha1" "github.com/aws/aws-application-networking-k8s/pkg/k8s" "github.com/aws/aws-application-networking-k8s/pkg/model/core" @@ -73,6 +76,36 @@ func (r *resourceMapper) VpcAssociationPolicyToGateway(ctx context.Context, vap return policyToTargetRefObj(r, ctx, vap, &gateway_api.Gateway{}) } +func (r *resourceMapper) GatewayToIAMAuthPolicies(ctx context.Context, gw *gateway_api.Gateway) []*anv1alpha1.IAMAuthPolicy { + policies, _ := gateway.GetAttachedPolicies(ctx, r.client, k8s.NamespacedName(gw), &anv1alpha1.IAMAuthPolicy{}) + return policies +} + +func (r *resourceMapper) GatewayToAccessLogPolicies(ctx context.Context, gw *gateway_api.Gateway) []*anv1alpha1.AccessLogPolicy { + policies, _ := gateway.GetAttachedPolicies(ctx, r.client, k8s.NamespacedName(gw), &anv1alpha1.AccessLogPolicy{}) + return policies +} + +func (r *resourceMapper) HTTPRouteToIAMAuthPolicies(ctx context.Context, route *gateway_api.HTTPRoute) []*anv1alpha1.IAMAuthPolicy { + policy, _ := gateway.GetAttachedPolicies(ctx, r.client, k8s.NamespacedName(route), &anv1alpha1.IAMAuthPolicy{}) + return policy +} + +func (r *resourceMapper) HTTPRouteToAccessLogPolicies(ctx context.Context, route *gateway_api.HTTPRoute) []*anv1alpha1.AccessLogPolicy { + policies, _ := gateway.GetAttachedPolicies(ctx, r.client, k8s.NamespacedName(route), &anv1alpha1.AccessLogPolicy{}) + return policies +} + +func (r *resourceMapper) GRPCRouteToIAMAuthPolicies(ctx context.Context, route *gateway_api_v1alpha2.GRPCRoute) []*anv1alpha1.IAMAuthPolicy { + policies, _ := gateway.GetAttachedPolicies(ctx, r.client, k8s.NamespacedName(route), &anv1alpha1.IAMAuthPolicy{}) + return policies +} + +func (r *resourceMapper) GRPCRouteToAccessLogPolicies(ctx context.Context, route *gateway_api_v1alpha2.GRPCRoute) []*anv1alpha1.AccessLogPolicy { + policies, _ := gateway.GetAttachedPolicies(ctx, r.client, k8s.NamespacedName(route), &anv1alpha1.AccessLogPolicy{}) + return policies +} + func policyToTargetRefObj[T client.Object](r *resourceMapper, ctx context.Context, policy core.Policy, retObj T) T { null := *new(T) if policy == nil { diff --git a/controllers/iamauthpolicy_controller.go b/controllers/iamauthpolicy_controller.go new file mode 100644 index 00000000..1e11f922 --- /dev/null +++ b/controllers/iamauthpolicy_controller.go @@ -0,0 +1,110 @@ +/* +Copyright 2021. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package controllers + +import ( + "context" + + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/client-go/tools/record" + ctrl "sigs.k8s.io/controller-runtime" + pkg_builder "sigs.k8s.io/controller-runtime/pkg/builder" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/predicate" + "sigs.k8s.io/controller-runtime/pkg/source" + gwv1alpha2 "sigs.k8s.io/gateway-api/apis/v1alpha2" + gwvv1beta1 "sigs.k8s.io/gateway-api/apis/v1beta1" + + "github.com/aws/aws-application-networking-k8s/controllers/eventhandlers" + anv1alpha1 "github.com/aws/aws-application-networking-k8s/pkg/apis/applicationnetworking/v1alpha1" + "github.com/aws/aws-application-networking-k8s/pkg/aws" + "github.com/aws/aws-application-networking-k8s/pkg/deploy" + "github.com/aws/aws-application-networking-k8s/pkg/k8s" + "github.com/aws/aws-application-networking-k8s/pkg/latticestore" + lattice_runtime "github.com/aws/aws-application-networking-k8s/pkg/runtime" + "github.com/aws/aws-application-networking-k8s/pkg/utils/gwlog" +) + +const ( + authPolicyFinalizer = "iamauthpolicy.k8s.aws/resources" +) + +type authPolicyReconciler struct { + log gwlog.Logger + client client.Client + scheme *runtime.Scheme + finalizerManager k8s.FinalizerManager + eventRecorder record.EventRecorder + cloud aws.Cloud + dataStore *latticestore.LatticeDataStore + stackMarshaller deploy.StackMarshaller +} + +func RegisterIAMAuthPolicyController( + log gwlog.Logger, + cloud aws.Cloud, + dataStore *latticestore.LatticeDataStore, + finalizerManager k8s.FinalizerManager, + mgr ctrl.Manager, +) error { + k8sClient := mgr.GetClient() + scheme := mgr.GetScheme() + evtRec := mgr.GetEventRecorderFor("iamauthpolicy") + + stackMarshaller := deploy.NewDefaultStackMarshaller() + + r := &authPolicyReconciler{ + log: log, + client: k8sClient, + scheme: scheme, + finalizerManager: finalizerManager, + eventRecorder: evtRec, + cloud: cloud, + stackMarshaller: stackMarshaller, + dataStore: dataStore, + } + + gatewayEventHandler := eventhandlers.NewGatewayEventHandler(log, k8sClient) + httpRouteEventHandler := eventhandlers.NewHTTPRouteEventHandler(log, k8sClient) + grpcRouteEventHandler := eventhandlers.NewGRPCRouteEventHandler(log, k8sClient) + + builder := ctrl.NewControllerManagedBy(mgr). + For(&anv1alpha1.IAMAuthPolicy{}, pkg_builder.WithPredicates(predicate.GenerationChangedPredicate{})). + Watches(&source.Kind{Type: &gwvv1beta1.Gateway{}}, gatewayEventHandler.MapToIAMAuthPolicies(), pkg_builder.WithPredicates(predicate.GenerationChangedPredicate{})). + Watches(&source.Kind{Type: &gwvv1beta1.HTTPRoute{}}, httpRouteEventHandler.MapToIAMAuthPolicies(), pkg_builder.WithPredicates(predicate.GenerationChangedPredicate{})). + Watches(&source.Kind{Type: &gwv1alpha2.GRPCRoute{}}, grpcRouteEventHandler.MapToIAMAuthPolicies(), pkg_builder.WithPredicates(predicate.GenerationChangedPredicate{})) + return builder.Complete(r) +} + +func (r *authPolicyReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { + r.log.Infow("reconcile", "name", req.Name) + recErr := r.reconcile(ctx, req) + res, retryErr := lattice_runtime.HandleReconcileError(recErr) + if res.RequeueAfter != 0 { + r.log.Infow("requeue request", "name", req.Name, "requeueAfter", res.RequeueAfter) + } else if res.Requeue { + r.log.Infow("requeue request", "name", req.Name) + } else if retryErr == nil { + r.log.Infow("reconciled", "name", req.Name) + } + return res, retryErr +} + +func (r *authPolicyReconciler) reconcile(ctx context.Context, req ctrl.Request) error { + //TODO: implement reconcile + return nil +} diff --git a/pkg/apis/applicationnetworking/v1alpha1/accesslogpolicy_types.go b/pkg/apis/applicationnetworking/v1alpha1/accesslogpolicy_types.go index a3ec9578..2ce38789 100644 --- a/pkg/apis/applicationnetworking/v1alpha1/accesslogpolicy_types.go +++ b/pkg/apis/applicationnetworking/v1alpha1/accesslogpolicy_types.go @@ -95,8 +95,8 @@ func (p *AccessLogPolicy) GetNamespacedName() types.NamespacedName { func (pl *AccessLogPolicyList) GetItems() []core.Policy { items := make([]core.Policy, len(pl.Items)) - for i, item := range pl.Items { - items[i] = &item + for i := range pl.Items { + items[i] = &pl.Items[i] } return items } diff --git a/pkg/apis/applicationnetworking/v1alpha1/authpolicy_types.go b/pkg/apis/applicationnetworking/v1alpha1/authpolicy_types.go index 73482d19..31b1be23 100644 --- a/pkg/apis/applicationnetworking/v1alpha1/authpolicy_types.go +++ b/pkg/apis/applicationnetworking/v1alpha1/authpolicy_types.go @@ -93,8 +93,8 @@ func (p *IAMAuthPolicy) GetNamespacedName() types.NamespacedName { func (pl *IAMAuthPolicyList) GetItems() []core.Policy { items := make([]core.Policy, len(pl.Items)) - for i, item := range pl.Items { - items[i] = &item + for i := range pl.Items { + items[i] = &pl.Items[i] } return items } diff --git a/pkg/apis/applicationnetworking/v1alpha1/targetgrouppolicy_types.go b/pkg/apis/applicationnetworking/v1alpha1/targetgrouppolicy_types.go index a9771374..8a2bd0bf 100644 --- a/pkg/apis/applicationnetworking/v1alpha1/targetgrouppolicy_types.go +++ b/pkg/apis/applicationnetworking/v1alpha1/targetgrouppolicy_types.go @@ -172,8 +172,8 @@ func (p *TargetGroupPolicy) SetStatusConditions(conditions []metav1.Condition) { func (pl *TargetGroupPolicyList) GetItems() []core.Policy { items := make([]core.Policy, len(pl.Items)) - for i, item := range pl.Items { - items[i] = &item + for i := range pl.Items { + items[i] = &pl.Items[i] } return items } diff --git a/pkg/apis/applicationnetworking/v1alpha1/vpcassociationpolicy_types.go b/pkg/apis/applicationnetworking/v1alpha1/vpcassociationpolicy_types.go index 580375b9..f58e1021 100644 --- a/pkg/apis/applicationnetworking/v1alpha1/vpcassociationpolicy_types.go +++ b/pkg/apis/applicationnetworking/v1alpha1/vpcassociationpolicy_types.go @@ -105,8 +105,8 @@ func (p *VpcAssociationPolicy) GetNamespacedName() types.NamespacedName { func (pl *VpcAssociationPolicyList) GetItems() []core.Policy { items := make([]core.Policy, len(pl.Items)) - for i, item := range pl.Items { - items[i] = &item + for i := range pl.Items { + items[i] = &pl.Items[i] } return items } diff --git a/pkg/gateway/model_build_service_network_test.go b/pkg/gateway/model_build_service_network_test.go index 908d27b9..ecde8a37 100644 --- a/pkg/gateway/model_build_service_network_test.go +++ b/pkg/gateway/model_build_service_network_test.go @@ -33,14 +33,14 @@ func Test_SNModelBuild(t *testing.T) { }, } tests := []struct { - name string - gw *gwv1beta1.Gateway - vpcAssociationPolicy *anv1alpha1.VpcAssociationPolicy - wantErr error - wantName string - wantNamespace string - wantIsDeleted bool - associateToVPC bool + name string + gw *gwv1beta1.Gateway + vpcAssociationPolicies []*anv1alpha1.VpcAssociationPolicy + wantErr error + wantName string + wantNamespace string + wantIsDeleted bool + associateToVPC bool }{ { name: "Adding SN in default namespace, no annotation on VPC association, associate to VPC by default", @@ -120,17 +120,32 @@ func Test_SNModelBuild(t *testing.T) { Finalizers: []string{"gateway.k8s.aws/resources"}, }, }, - vpcAssociationPolicy: &anv1alpha1.VpcAssociationPolicy{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test-vpc-association-policy", + vpcAssociationPolicies: []*anv1alpha1.VpcAssociationPolicy{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "test-vpc-association-policy", + }, + Spec: anv1alpha1.VpcAssociationPolicySpec{ + TargetRef: &gwv1alpha2.PolicyTargetReference{ + Group: gwv1beta1.GroupName, + Kind: "Gateway", + Name: "gw1", + }, + SecurityGroupIds: []anv1alpha1.SecurityGroupId{"sg-123456", "sg-654321"}, + }, }, - Spec: anv1alpha1.VpcAssociationPolicySpec{ - TargetRef: &gwv1alpha2.PolicyTargetReference{ - Group: gwv1beta1.GroupName, - Kind: "Gateway", - Name: "gw1", + { + ObjectMeta: metav1.ObjectMeta{ + Name: "the-second-vpc-association-policy-will-not-take-effect", + }, + Spec: anv1alpha1.VpcAssociationPolicySpec{ + TargetRef: &gwv1alpha2.PolicyTargetReference{ + Group: gwv1beta1.GroupName, + Kind: "Gateway", + Name: "gw1", + }, + SecurityGroupIds: []anv1alpha1.SecurityGroupId{"sg-will-not-take-effect"}, }, - SecurityGroupIds: []anv1alpha1.SecurityGroupId{"sg-123456", "sg-654321"}, }, }, wantErr: nil, @@ -147,17 +162,19 @@ func Test_SNModelBuild(t *testing.T) { Finalizers: []string{"gateway.k8s.aws/resources"}, }, }, - vpcAssociationPolicy: &anv1alpha1.VpcAssociationPolicy{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test-vpc-association-policy", - }, - Spec: anv1alpha1.VpcAssociationPolicySpec{ - TargetRef: &gwv1alpha2.PolicyTargetReference{ - Group: gwv1beta1.GroupName, - Kind: "Gateway", - Name: "gw1", + vpcAssociationPolicies: []*anv1alpha1.VpcAssociationPolicy{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "test-vpc-association-policy", + }, + Spec: anv1alpha1.VpcAssociationPolicySpec{ + TargetRef: &gwv1alpha2.PolicyTargetReference{ + Group: gwv1beta1.GroupName, + Kind: "Gateway", + Name: "gw1", + }, + AssociateWithVpc: &trueBool, }, - AssociateWithVpc: &trueBool, }, }, wantErr: nil, @@ -174,17 +191,19 @@ func Test_SNModelBuild(t *testing.T) { Finalizers: []string{"gateway.k8s.aws/resources"}, }, }, - vpcAssociationPolicy: &anv1alpha1.VpcAssociationPolicy{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test-vpc-association-policy", - }, - Spec: anv1alpha1.VpcAssociationPolicySpec{ - TargetRef: &gwv1alpha2.PolicyTargetReference{ - Group: gwv1beta1.GroupName, - Kind: "Gateway", - Name: "gw1", + vpcAssociationPolicies: []*anv1alpha1.VpcAssociationPolicy{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "test-vpc-association-policy", + }, + Spec: anv1alpha1.VpcAssociationPolicySpec{ + TargetRef: &gwv1alpha2.PolicyTargetReference{ + Group: gwv1beta1.GroupName, + Kind: "Gateway", + Name: "gw1", + }, + AssociateWithVpc: &falseBool, }, - AssociateWithVpc: &falseBool, }, }, wantErr: nil, @@ -204,8 +223,8 @@ func Test_SNModelBuild(t *testing.T) { mockClient.EXPECT().List(ctx, gomock.Any(), gomock.Any()).DoAndReturn( func(ctx context.Context, policyList *anv1alpha1.VpcAssociationPolicyList, arg3 ...interface{}) error { policyList.Items = append(policyList.Items, notRelatedVpcAssociationPolicy) - if tt.vpcAssociationPolicy != nil { - policyList.Items = append(policyList.Items, *tt.vpcAssociationPolicy) + for _, p := range tt.vpcAssociationPolicies { + policyList.Items = append(policyList.Items, *p) } return nil }, @@ -220,8 +239,8 @@ func Test_SNModelBuild(t *testing.T) { assert.Equal(t, tt.wantNamespace, got.Spec.Namespace) assert.Equal(t, tt.wantIsDeleted, got.Spec.IsDeleted) assert.Equal(t, tt.associateToVPC, got.Spec.AssociateToVPC) - if tt.vpcAssociationPolicy != nil { - assert.Equal(t, securityGroupIdsToStringPointersSlice(tt.vpcAssociationPolicy.Spec.SecurityGroupIds), got.Spec.SecurityGroupIds) + if tt.vpcAssociationPolicies != nil { + assert.Equal(t, securityGroupIdsToStringPointersSlice(tt.vpcAssociationPolicies[0].Spec.SecurityGroupIds), got.Spec.SecurityGroupIds) } } }) diff --git a/pkg/gateway/model_build_servicenetwork.go b/pkg/gateway/model_build_servicenetwork.go index c60615ac..ba69a6d8 100644 --- a/pkg/gateway/model_build_servicenetwork.go +++ b/pkg/gateway/model_build_servicenetwork.go @@ -12,6 +12,7 @@ import ( "github.com/aws/aws-application-networking-k8s/pkg/k8s" "github.com/aws/aws-application-networking-k8s/pkg/model/core" model "github.com/aws/aws-application-networking-k8s/pkg/model/lattice" + "github.com/aws/aws-application-networking-k8s/pkg/utils/gwlog" ) const ( @@ -33,13 +34,21 @@ func NewServiceNetworkModelBuilder(client client.Client) *serviceNetworkModelBui } func (b *serviceNetworkModelBuilder) Build(ctx context.Context, gw *gateway_api.Gateway) (core.Stack, *model.ServiceNetwork, error) { stack := core.NewDefaultStack(core.StackID(k8s.NamespacedName(gw))) - vpcAssociationPolicy, err := GetAttachedPolicy(ctx, b.client, k8s.NamespacedName(gw), &anv1alpha1.VpcAssociationPolicy{}) + vpcAssociationPolicies, err := GetAttachedPolicies(ctx, b.client, k8s.NamespacedName(gw), &anv1alpha1.VpcAssociationPolicy{}) if err != nil { return nil, nil, err } + var vap *anv1alpha1.VpcAssociationPolicy + if len(vpcAssociationPolicies) >= 1 { + vap = vpcAssociationPolicies[0] + if len(vpcAssociationPolicies) > 1 { + gwlog.FallbackLogger.Errorf("More than one VpcAssociationPolicy is attached to the gateway [%s/%s], "+ + "only the first one VpcAssociationPolicy [%s/%s] will take effect, other VpcAssociationPolicies will be ignored", gw.Namespace, gw.Name, vap.Namespace, vap.Name) + } + } task := &serviceNetworkModelBuildTask{ gateway: gw, - vpcAssociationPolicy: vpcAssociationPolicy, + vpcAssociationPolicy: vap, stack: stack, } diff --git a/pkg/gateway/model_build_targetgroup.go b/pkg/gateway/model_build_targetgroup.go index 154f3636..d499812f 100644 --- a/pkg/gateway/model_build_targetgroup.go +++ b/pkg/gateway/model_build_targetgroup.go @@ -150,15 +150,19 @@ func (t *svcExportTargetGroupModelBuildTask) buildTargetGroupForServiceExportCre return nil, err } - tgp, err := GetAttachedPolicy(ctx, t.client, k8s.NamespacedName(t.serviceExport), &anv1alpha1.TargetGroupPolicy{}) + tgps, err := GetAttachedPolicies(ctx, t.client, k8s.NamespacedName(t.serviceExport), &anv1alpha1.TargetGroupPolicy{}) if err != nil { return nil, err } - protocol := "HTTP" protocolVersion := vpclattice.TargetGroupProtocolVersionHttp1 var healthCheckConfig *vpclattice.HealthCheckConfig - if tgp != nil { + if len(tgps) >= 1 { + tgp := tgps[0] + if len(tgps) > 1 { + t.log.Errorf("More than one TargetGroupPolicy is attached to the serviceExport %s, "+ + "only the first one TargetGroupPolicy [%s/%s] will take effect, other TargetGroupPolicies will be ignored", k8s.NamespacedName(t.serviceExport), tgp.Namespace, tgp.Name) + } if tgp.Spec.Protocol != nil { protocol = *tgp.Spec.Protocol } @@ -390,11 +394,18 @@ func (t *latticeServiceModelBuildTask) buildTargetGroupSpec( Namespace: namespace, Name: string(backendRef.Name()), } - tgp, err := GetAttachedPolicy(ctx, t.client, refObjNamespacedName, &anv1alpha1.TargetGroupPolicy{}) - + tgps, err := GetAttachedPolicies(ctx, t.client, refObjNamespacedName, &anv1alpha1.TargetGroupPolicy{}) if err != nil { return model.TargetGroupSpec{}, err } + var tgp *anv1alpha1.TargetGroupPolicy + if len(tgps) >= 1 { + tgp = tgps[0] + if len(tgps) > 1 { + t.log.Errorf("More than one TargetGroupPolicy is attached to the k8sService %s, "+ + "only the first one TargetGroupPolicy [%s/%s] will take effect, other TargetGroupPolicies will be ignored", refObjNamespacedName, tgp.Namespace, tgp.Name) + } + } protocol := "HTTP" protocolVersion := vpclattice.TargetGroupProtocolVersionHttp1 var healthCheckConfig *vpclattice.HealthCheckConfig diff --git a/pkg/gateway/utils.go b/pkg/gateway/utils.go index c9a10173..023f5019 100644 --- a/pkg/gateway/utils.go +++ b/pkg/gateway/utils.go @@ -6,6 +6,7 @@ import ( corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/meta" + "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/controller-runtime/pkg/client" gwv1beta1 "sigs.k8s.io/gateway-api/apis/v1beta1" @@ -14,11 +15,11 @@ import ( "github.com/aws/aws-application-networking-k8s/pkg/model/core" ) -func GetAttachedPolicy[T core.Policy](ctx context.Context, k8sClient client.Client, refObjNamespacedName types.NamespacedName, policy T) (T, error) { - null := *new(T) - policyList, expectedTargetRefObjGroup, expectedTargetRefObjKind, err := policyTypeToPolicyListAndTargetRefGroupKind(policy) +func GetAttachedPolicies[T core.Policy](ctx context.Context, k8sClient client.Client, refObjNamespacedName types.NamespacedName, policy T) ([]T, error) { + var matchedPolicies []T + policyList, validTargetRefObjGroupKinds, err := policyTypeToPolicyListAndValidTargetRefObjGKs(policy) if err != nil { - return null, err + return matchedPolicies, err } err = k8sClient.List(ctx, policyList.(client.ObjectList), &client.ListOptions{ @@ -27,16 +28,19 @@ func GetAttachedPolicy[T core.Policy](ctx context.Context, k8sClient client.Clie if err != nil { if meta.IsNoMatchError(err) { // CRD does not exist - return null, nil + return matchedPolicies, nil } - return null, err + return matchedPolicies, err } for _, p := range policyList.GetItems() { targetRef := p.GetTargetRef() if targetRef == nil { continue } - groupKindMatch := targetRef.Group == expectedTargetRefObjGroup && targetRef.Kind == expectedTargetRefObjKind + _, groupKindMatch := validTargetRefObjGroupKinds[schema.GroupKind{ + Group: string(targetRef.Group), + Kind: string(targetRef.Kind), + }] nameMatch := string(targetRef.Name) == refObjNamespacedName.Name retrievedNamespace := p.GetNamespacedName().Namespace @@ -45,19 +49,35 @@ func GetAttachedPolicy[T core.Policy](ctx context.Context, k8sClient client.Clie } namespaceMatch := retrievedNamespace == refObjNamespacedName.Namespace if groupKindMatch && nameMatch && namespaceMatch { - return p.(T), nil + matchedPolicies = append(matchedPolicies, p.(T)) } } - return null, nil + return matchedPolicies, nil } -func policyTypeToPolicyListAndTargetRefGroupKind(policyType core.Policy) (core.PolicyList, gwv1beta1.Group, gwv1beta1.Kind, error) { +func policyTypeToPolicyListAndValidTargetRefObjGKs(policyType core.Policy) (core.PolicyList, map[schema.GroupKind]interface{}, error) { switch policyType.(type) { case *anv1alpha1.VpcAssociationPolicy: - return &anv1alpha1.VpcAssociationPolicyList{}, gwv1beta1.GroupName, "Gateway", nil + return &anv1alpha1.VpcAssociationPolicyList{}, map[schema.GroupKind]interface{}{ + {Group: gwv1beta1.GroupName, Kind: "Gateway"}: struct{}{}, + }, nil case *anv1alpha1.TargetGroupPolicy: - return &anv1alpha1.TargetGroupPolicyList{}, corev1.GroupName, "Service", nil + return &anv1alpha1.TargetGroupPolicyList{}, map[schema.GroupKind]interface{}{ + {Group: corev1.GroupName, Kind: "Service"}: struct{}{}, + }, nil + case *anv1alpha1.IAMAuthPolicy: + return &anv1alpha1.IAMAuthPolicyList{}, map[schema.GroupKind]interface{}{ + {Group: gwv1beta1.GroupName, Kind: "Gateway"}: struct{}{}, + {Group: gwv1beta1.GroupName, Kind: "HttpRoute"}: struct{}{}, + {Group: gwv1beta1.GroupName, Kind: "GRPCRoute"}: struct{}{}, + }, nil + case *anv1alpha1.AccessLogPolicy: + return &anv1alpha1.AccessLogPolicyList{}, map[schema.GroupKind]interface{}{ + {Group: gwv1beta1.GroupName, Kind: "Gateway"}: struct{}{}, + {Group: gwv1beta1.GroupName, Kind: "HttpRoute"}: struct{}{}, + {Group: gwv1beta1.GroupName, Kind: "GRPCRoute"}: struct{}{}, + }, nil default: - return nil, "", "", fmt.Errorf("unsupported policy type %T", policyType) + return nil, nil, fmt.Errorf("unknown policy type %T", policyType) } } diff --git a/pkg/gateway/utils_test.go b/pkg/gateway/utils_test.go index 283d2395..7b04037a 100644 --- a/pkg/gateway/utils_test.go +++ b/pkg/gateway/utils_test.go @@ -4,6 +4,7 @@ import ( "context" "testing" + "github.com/aws/aws-sdk-go/aws" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" corev1 "k8s.io/api/core/v1" @@ -16,17 +17,17 @@ import ( "github.com/aws/aws-application-networking-k8s/pkg/model/core" ) -func Test_getAttachedPolicy(t *testing.T) { +func Test_GetAttachedPolicy(t *testing.T) { type args struct { refObjNamespacedName types.NamespacedName policy core.Policy } type testCase struct { - name string - args args - expectedK8sClientReturnedPolicy core.Policy - want core.Policy - expectPolicyNotFound bool + name string + args args + expectedK8sClientReturnedPolicies []core.Policy + want []core.Policy + expectPolicyNotFound bool } ns := "test-ns" typedNs := gwv1alpha2.Namespace(ns) @@ -49,12 +50,11 @@ func Test_getAttachedPolicy(t *testing.T) { ProtocolVersion: &protocolVersion, }, } - policyTargetRefSectionNil := targetGroupPolicyHappyPath.DeepCopyObject().(*anv1alpha1.TargetGroupPolicy) policyTargetRefSectionNil.Spec.TargetRef = nil - policyTargetRefKindWrong := targetGroupPolicyHappyPath.DeepCopyObject().(*anv1alpha1.TargetGroupPolicy) - policyTargetRefKindWrong.Spec.TargetRef.Kind = "ServiceImport" + tgpTargetRefKindWrong := targetGroupPolicyHappyPath.DeepCopyObject().(*anv1alpha1.TargetGroupPolicy) + tgpTargetRefKindWrong.Spec.TargetRef.Kind = "ServiceImport" notRelatedTargetGroupPolicy := targetGroupPolicyHappyPath.DeepCopyObject().(*anv1alpha1.TargetGroupPolicy) notRelatedTargetGroupPolicy.Spec.TargetRef.Name = "another-svc" @@ -78,6 +78,61 @@ func Test_getAttachedPolicy(t *testing.T) { notRelatedVpcAssociationPolicy := vpcAssociationPolicyHappyPath.DeepCopyObject().(*anv1alpha1.VpcAssociationPolicy) notRelatedVpcAssociationPolicy.Spec.TargetRef.Name = "another-gw" + iamAuthPolicyHappyPath := &anv1alpha1.IAMAuthPolicy{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-iam-auth-policy", + Namespace: ns, + }, + Spec: anv1alpha1.IAMAuthPolicySpec{ + TargetRef: &gwv1alpha2.PolicyTargetReference{ + Group: gwv1alpha2.GroupName, + Name: "gw-1", + Kind: "Gateway", + Namespace: &typedNs, + }, + Policy: "policy content", + }, + } + + notRelatedIAMAuthPolicy := iamAuthPolicyHappyPath.DeepCopyObject().(*anv1alpha1.IAMAuthPolicy) + notRelatedIAMAuthPolicy.Spec.TargetRef.Name = "another-gw" + + accessLogPolicy1HappyPath := &anv1alpha1.AccessLogPolicy{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-iam-auth-policy", + Namespace: ns, + }, + Spec: anv1alpha1.AccessLogPolicySpec{ + TargetRef: &gwv1alpha2.PolicyTargetReference{ + Group: gwv1alpha2.GroupName, + Name: "grpcroute-1", + Kind: "GRPCRoute", + Namespace: &typedNs, + }, + DestinationArn: aws.String("test-destination-arn-1"), + }, + } + accessLogPolicy2HappyPath := &anv1alpha1.AccessLogPolicy{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-iam-auth-policy", + Namespace: ns, + }, + Spec: anv1alpha1.AccessLogPolicySpec{ + TargetRef: &gwv1alpha2.PolicyTargetReference{ + Group: gwv1alpha2.GroupName, + Name: "grpcroute-1", + Kind: "GRPCRoute", + Namespace: &typedNs, + }, + DestinationArn: aws.String("test-destination-arn-2"), + }, + } + notRelatedAccessLogPolicy := accessLogPolicy1HappyPath.DeepCopyObject().(*anv1alpha1.AccessLogPolicy) + notRelatedAccessLogPolicy.Spec.TargetRef.Name = "another-grpcroute" + + alpTargetRefKindWrong := accessLogPolicy1HappyPath.DeepCopyObject().(*anv1alpha1.AccessLogPolicy) + alpTargetRefKindWrong.Spec.TargetRef.Kind = "Service" + var tests = []testCase{ { name: "Get k8sService attached TargetGroupPolicy, happy path", @@ -88,22 +143,61 @@ func Test_getAttachedPolicy(t *testing.T) { }, policy: &anv1alpha1.TargetGroupPolicy{}, }, - expectedK8sClientReturnedPolicy: targetGroupPolicyHappyPath, - want: targetGroupPolicyHappyPath, - expectPolicyNotFound: false, + expectedK8sClientReturnedPolicies: []core.Policy{targetGroupPolicyHappyPath}, + want: []core.Policy{targetGroupPolicyHappyPath}, + expectPolicyNotFound: false, + }, + { + name: "Get gateway attached IAMAuthPolicy, happy path", + args: args{ + refObjNamespacedName: types.NamespacedName{ + Namespace: ns, + Name: "gw-1", + }, + policy: &anv1alpha1.IAMAuthPolicy{}, + }, + expectedK8sClientReturnedPolicies: []core.Policy{iamAuthPolicyHappyPath}, + want: []core.Policy{iamAuthPolicyHappyPath}, + expectPolicyNotFound: false, + }, + { + name: "Get GRPCRoute attached AccessLogPolicy, happy path, be able to get all targetRef Matched policies", + args: args{ + refObjNamespacedName: types.NamespacedName{ + Namespace: ns, + Name: "grpcroute-1", + }, + policy: &anv1alpha1.AccessLogPolicy{}, + }, + expectedK8sClientReturnedPolicies: []core.Policy{accessLogPolicy1HappyPath, accessLogPolicy2HappyPath}, + want: []core.Policy{accessLogPolicy1HappyPath, accessLogPolicy2HappyPath}, + expectPolicyNotFound: false, }, { name: "Get k8sService attached TargetGroupPolicy, policy not found due to input refObj name mismatch", args: args{ refObjNamespacedName: types.NamespacedName{ Namespace: ns, - Name: "another-svc", + Name: "another-svc-1", }, policy: &anv1alpha1.TargetGroupPolicy{}, }, - want: nil, - expectedK8sClientReturnedPolicy: targetGroupPolicyHappyPath, - expectPolicyNotFound: true, + want: nil, + expectedK8sClientReturnedPolicies: []core.Policy{targetGroupPolicyHappyPath}, + expectPolicyNotFound: true, + }, + { + name: "Get gateway attached IAMAuthPolicy, policy not found due to input refObj name mismatch", + args: args{ + refObjNamespacedName: types.NamespacedName{ + Namespace: ns, + Name: "another-gw-1", + }, + policy: &anv1alpha1.IAMAuthPolicy{}, + }, + want: nil, + expectedK8sClientReturnedPolicies: []core.Policy{iamAuthPolicyHappyPath}, + expectPolicyNotFound: true, }, { name: "Get k8sService attached TargetGroupPolicy, policy not found due to cluster don't have matched policy", @@ -114,9 +208,9 @@ func Test_getAttachedPolicy(t *testing.T) { }, policy: &anv1alpha1.TargetGroupPolicy{}, }, - want: nil, - expectedK8sClientReturnedPolicy: nil, - expectPolicyNotFound: true, + want: nil, + expectedK8sClientReturnedPolicies: nil, + expectPolicyNotFound: true, }, { name: "Get k8sService attached TargetGroupPolicy, policy not found due to policy targetRef section is nil", @@ -127,9 +221,9 @@ func Test_getAttachedPolicy(t *testing.T) { }, policy: &anv1alpha1.TargetGroupPolicy{}, }, - expectedK8sClientReturnedPolicy: policyTargetRefSectionNil, - want: nil, - expectPolicyNotFound: true, + expectedK8sClientReturnedPolicies: []core.Policy{policyTargetRefSectionNil}, + want: nil, + expectPolicyNotFound: true, }, { name: "Get k8sService attached TargetGroupPolicy, policy not found due to policy targetRef Kind mismatch", @@ -140,9 +234,22 @@ func Test_getAttachedPolicy(t *testing.T) { }, policy: &anv1alpha1.TargetGroupPolicy{}, }, - expectedK8sClientReturnedPolicy: policyTargetRefKindWrong, - want: nil, - expectPolicyNotFound: true, + expectedK8sClientReturnedPolicies: []core.Policy{tgpTargetRefKindWrong}, + want: nil, + expectPolicyNotFound: true, + }, + { + name: "Get GRPCRoute attached AccessLogPolicy, policy not found due to policy targetRef Kind mismatch", + args: args{ + refObjNamespacedName: types.NamespacedName{ + Namespace: ns, + Name: "grpcroute-1", + }, + policy: &anv1alpha1.AccessLogPolicy{}, + }, + expectedK8sClientReturnedPolicies: []core.Policy{alpTargetRefKindWrong}, + want: nil, + expectPolicyNotFound: true, }, { name: "Get k8sGateway attached VpcAssociationPolicy, happy path", @@ -153,9 +260,9 @@ func Test_getAttachedPolicy(t *testing.T) { }, policy: &anv1alpha1.VpcAssociationPolicy{}, }, - expectedK8sClientReturnedPolicy: vpcAssociationPolicyHappyPath, - want: vpcAssociationPolicyHappyPath, - expectPolicyNotFound: false, + expectedK8sClientReturnedPolicies: []core.Policy{vpcAssociationPolicyHappyPath}, + want: []core.Policy{vpcAssociationPolicyHappyPath}, + expectPolicyNotFound: false, }, { name: "Get k8sGateway attached VpcAssociationPolicy, Not found", @@ -166,9 +273,9 @@ func Test_getAttachedPolicy(t *testing.T) { }, policy: &anv1alpha1.VpcAssociationPolicy{}, }, - expectedK8sClientReturnedPolicy: nil, - want: nil, - expectPolicyNotFound: true, + expectedK8sClientReturnedPolicies: nil, + want: nil, + expectPolicyNotFound: true, }, } c := gomock.NewController(t) @@ -182,9 +289,10 @@ func Test_getAttachedPolicy(t *testing.T) { mockK8sClient.EXPECT().List(ctx, gomock.Any(), gomock.Any()).DoAndReturn( func(ctx context.Context, list *anv1alpha1.TargetGroupPolicyList, arg3 ...interface{}) error { list.Items = append(list.Items, *notRelatedTargetGroupPolicy) - if tt.expectedK8sClientReturnedPolicy != nil { - policy := tt.expectedK8sClientReturnedPolicy.(*anv1alpha1.TargetGroupPolicy) + for _, p := range tt.expectedK8sClientReturnedPolicies { + policy := p.(*anv1alpha1.TargetGroupPolicy) list.Items = append(list.Items, *policy) + } return nil }) @@ -192,21 +300,44 @@ func Test_getAttachedPolicy(t *testing.T) { mockK8sClient.EXPECT().List(ctx, gomock.Any(), gomock.Any()).DoAndReturn( func(ctx context.Context, list *anv1alpha1.VpcAssociationPolicyList, arg3 ...interface{}) error { list.Items = append(list.Items, *notRelatedVpcAssociationPolicy) - if tt.expectedK8sClientReturnedPolicy != nil { - policy := tt.expectedK8sClientReturnedPolicy.(*anv1alpha1.VpcAssociationPolicy) + for _, p := range tt.expectedK8sClientReturnedPolicies { + policy := p.(*anv1alpha1.VpcAssociationPolicy) + list.Items = append(list.Items, *policy) + + } + return nil + }) + } else if _, ok := tt.args.policy.(*anv1alpha1.IAMAuthPolicy); ok { + mockK8sClient.EXPECT().List(ctx, gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, list *anv1alpha1.IAMAuthPolicyList, arg3 ...interface{}) error { + list.Items = append(list.Items, *notRelatedIAMAuthPolicy) + for _, p := range tt.expectedK8sClientReturnedPolicies { + policy := p.(*anv1alpha1.IAMAuthPolicy) + list.Items = append(list.Items, *policy) + } + return nil + }) + } else if _, ok := tt.args.policy.(*anv1alpha1.AccessLogPolicy); ok { + mockK8sClient.EXPECT().List(ctx, gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, list *anv1alpha1.AccessLogPolicyList, arg3 ...interface{}) error { + list.Items = append(list.Items, *notRelatedAccessLogPolicy) + for _, p := range tt.expectedK8sClientReturnedPolicies { + policy := p.(*anv1alpha1.AccessLogPolicy) list.Items = append(list.Items, *policy) } return nil }) + } else { + t.Errorf("unexpected policy type: %v", tt.args.policy) } - got, err := GetAttachedPolicy(ctx, mockK8sClient, tt.args.refObjNamespacedName, tt.args.policy) + got, err := GetAttachedPolicies(ctx, mockK8sClient, tt.args.refObjNamespacedName, tt.args.policy) if tt.expectPolicyNotFound { assert.Nil(t, err) assert.Nil(t, got) return } - assert.Equalf(t, tt.want, got, "GetAttachedPolicy(%v, %v, %v, %v)", ctx, mockK8sClient, tt.args.refObjNamespacedName, tt.args.policy) + assert.Equalf(t, tt.want, got, "GetAttachedPolicies(%v, %v, %v, %v)", ctx, mockK8sClient, tt.args.refObjNamespacedName, tt.args.policy) }) } }