Skip to content

Commit

Permalink
Kuadrantgh-628 fix deletion of dnsrecord and certificates on deletio…
Browse files Browse the repository at this point in the history
…n of gateway target for dnspolicy and tlspolicy
  • Loading branch information
laurafitzgerald committed Oct 20, 2023
1 parent 6965b6a commit 2e1971e
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 39 deletions.
44 changes: 30 additions & 14 deletions pkg/controllers/dnspolicy/dns_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,27 @@ func findMatchingManagedZone(originalHost, host string, zones []v1alpha1.Managed
}

func commonDNSRecordLabels(gwKey, apKey client.ObjectKey) map[string]string {
common := map[string]string{}
for k, v := range policyDNSRecordLabels(apKey) {
common[k] = v
}
for k, v := range gatewayDNSRecordLabels(gwKey) {
common[k] = v
}
return common
}

func policyDNSRecordLabels(apKey client.ObjectKey) map[string]string {
return map[string]string{
DNSPolicyBackRefAnnotation: apKey.Name,
fmt.Sprintf("%s-namespace", DNSPolicyBackRefAnnotation): apKey.Namespace,
LabelGatewayNSRef: gwKey.Namespace,
LabelGatewayReference: gwKey.Name,
}
}

func gatewayDNSRecordLabels(gwKey client.ObjectKey) map[string]string {
return map[string]string{
LabelGatewayNSRef: gwKey.Namespace,
LabelGatewayReference: gwKey.Name,
}
}

Expand Down Expand Up @@ -282,13 +298,13 @@ func createOrUpdateEndpoint(dnsName string, targets v1alpha1.Targets, recordType
}

// removeDNSForDeletedListeners remove any DNSRecords that are associated with listeners that no longer exist in this gateway
func (r *dnsHelper) removeDNSForDeletedListeners(ctx context.Context, upstreamGateway *gatewayv1beta1.Gateway) error {
func (dh *dnsHelper) removeDNSForDeletedListeners(ctx context.Context, upstreamGateway *gatewayv1beta1.Gateway) error {
dnsList := &v1alpha1.DNSRecordList{}
//List all dns records that belong to this gateway
labelSelector := &client.MatchingLabels{
LabelGatewayReference: upstreamGateway.Name,
}
if err := r.List(ctx, dnsList, labelSelector, &client.ListOptions{Namespace: upstreamGateway.Namespace}); err != nil {
if err := dh.List(ctx, dnsList, labelSelector, &client.ListOptions{Namespace: upstreamGateway.Namespace}); err != nil {
return err
}

Expand All @@ -301,7 +317,7 @@ func (r *dnsHelper) removeDNSForDeletedListeners(ctx context.Context, upstreamGa
}
}
if !listenerExists {
if err := r.Delete(ctx, &dns, &client.DeleteOptions{}); client.IgnoreNotFound(err) != nil {
if err := dh.Delete(ctx, &dns, &client.DeleteOptions{}); client.IgnoreNotFound(err) != nil {
return err
}
}
Expand All @@ -310,9 +326,9 @@ func (r *dnsHelper) removeDNSForDeletedListeners(ctx context.Context, upstreamGa

}

func (r *dnsHelper) getManagedZoneForListener(ctx context.Context, ns string, listener gatewayv1beta1.Listener) (*v1alpha1.ManagedZone, error) {
func (dh *dnsHelper) getManagedZoneForListener(ctx context.Context, ns string, listener gatewayv1beta1.Listener) (*v1alpha1.ManagedZone, error) {
var managedZones v1alpha1.ManagedZoneList
if err := r.List(ctx, &managedZones, client.InNamespace(ns)); err != nil {
if err := dh.List(ctx, &managedZones, client.InNamespace(ns)); err != nil {
log.FromContext(ctx).Error(err, "unable to list managed zones for gateway ", "in ns", ns)
return nil, err
}
Expand All @@ -325,37 +341,37 @@ func dnsRecordName(gatewayName, listenerName string) string {
return fmt.Sprintf("%s-%s", gatewayName, listenerName)
}

func (r *dnsHelper) createDNSRecordForListener(ctx context.Context, gateway *gatewayv1beta1.Gateway, dnsPolicy *v1alpha1.DNSPolicy, mz *v1alpha1.ManagedZone, listener gatewayv1beta1.Listener) (*v1alpha1.DNSRecord, error) {
func (dh *dnsHelper) createDNSRecordForListener(ctx context.Context, gateway *gatewayv1beta1.Gateway, dnsPolicy *v1alpha1.DNSPolicy, mz *v1alpha1.ManagedZone, listener gatewayv1beta1.Listener) (*v1alpha1.DNSRecord, error) {

log := log.FromContext(ctx)
log.Info("creating dns for gateway listener", "listener", listener.Name)
dnsRecord := r.buildDNSRecordForListener(gateway, dnsPolicy, listener, mz)
if err := controllerutil.SetControllerReference(mz, dnsRecord, r.Scheme()); err != nil {
dnsRecord := dh.buildDNSRecordForListener(gateway, dnsPolicy, listener, mz)
if err := controllerutil.SetControllerReference(mz, dnsRecord, dh.Scheme()); err != nil {
return dnsRecord, err
}

err := r.Create(ctx, dnsRecord, &client.CreateOptions{})
err := dh.Create(ctx, dnsRecord, &client.CreateOptions{})
if err != nil && !k8serrors.IsAlreadyExists(err) {
return dnsRecord, err
}
if err != nil && k8serrors.IsAlreadyExists(err) {
err = r.Get(ctx, client.ObjectKeyFromObject(dnsRecord), dnsRecord)
err = dh.Get(ctx, client.ObjectKeyFromObject(dnsRecord), dnsRecord)
if err != nil {
return dnsRecord, err
}
}
return dnsRecord, nil
}

func (r *dnsHelper) deleteDNSRecordForListener(ctx context.Context, owner metav1.Object, listener gatewayv1beta1.Listener) error {
func (dh *dnsHelper) deleteDNSRecordForListener(ctx context.Context, owner metav1.Object, listener gatewayv1beta1.Listener) error {
recordName := dnsRecordName(owner.GetName(), string(listener.Name))
dnsRecord := v1alpha1.DNSRecord{
ObjectMeta: metav1.ObjectMeta{
Name: recordName,
Namespace: owner.GetNamespace(),
},
}
return r.Delete(ctx, &dnsRecord, &client.DeleteOptions{})
return dh.Delete(ctx, &dnsRecord, &client.DeleteOptions{})
}

func isWildCardListener(l gatewayv1beta1.Listener) bool {
Expand Down
3 changes: 1 addition & 2 deletions pkg/controllers/dnspolicy/dnspolicy_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,7 @@ func (r *DNSPolicyReconciler) deleteResources(ctx context.Context, dnsPolicy *v1
if err != nil {
return err
}

if err := r.reconcileDNSRecords(ctx, dnsPolicy, gatewayDiffObj); err != nil {
if err = r.deleteDNSRecords(ctx, dnsPolicy); err != nil {
log.V(3).Info("error reconciling DNS records from delete, returning", "error", err)
return err
}
Expand Down
8 changes: 4 additions & 4 deletions pkg/controllers/dnspolicy/dnspolicy_dnsrecords.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ func (r *DNSPolicyReconciler) reconcileDNSRecords(ctx context.Context, dnsPolicy

for _, gw := range gwDiffObj.GatewaysWithInvalidPolicyRef {
log.V(1).Info("reconcileDNSRecords: gateway with invalid policy ref", "key", gw.Key())
err := r.deleteGatewayDNSRecords(ctx, gw.Gateway, dnsPolicy)
err := r.deleteDNSRecords(ctx, dnsPolicy)
if err != nil {
return err
}
}

// Reconcile DNSRecords for each gateway directly referred by the policy (existing and new)
for _, gw := range append(gwDiffObj.GatewaysWithValidPolicyRef, gwDiffObj.GatewaysMissingPolicyRef...) {
log.V(1).Info("reconcileDNSRecords: gateway with valid and missing policy ref", "key", gw.Key())
log.V(1).Info("reconcileDNSRecords: gateway with valid or missing policy ref", "key", gw.Key())
err := r.reconcileGatewayDNSRecords(ctx, gw.Gateway, dnsPolicy)
if err != nil {
return err
Expand Down Expand Up @@ -123,10 +123,10 @@ func (r *DNSPolicyReconciler) reconcileGatewayDNSRecords(ctx context.Context, ga
return nil
}

func (r *DNSPolicyReconciler) deleteGatewayDNSRecords(ctx context.Context, gateway *gatewayv1beta1.Gateway, dnsPolicy *v1alpha1.DNSPolicy) error {
func (r *DNSPolicyReconciler) deleteDNSRecords(ctx context.Context, dnsPolicy *v1alpha1.DNSPolicy) error {
log := crlog.FromContext(ctx)

listOptions := &client.ListOptions{LabelSelector: labels.SelectorFromSet(commonDNSRecordLabels(client.ObjectKeyFromObject(gateway), client.ObjectKeyFromObject(dnsPolicy)))}
listOptions := &client.ListOptions{LabelSelector: labels.SelectorFromSet(policyDNSRecordLabels(client.ObjectKeyFromObject(dnsPolicy)))}
recordsList := &v1alpha1.DNSRecordList{}
if err := r.Client().List(ctx, recordsList, listOptions); err != nil {
return err
Expand Down
38 changes: 27 additions & 11 deletions pkg/controllers/tlspolicy/tlspolicy_certmanager_certificates.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@ func (r *TLSPolicyReconciler) reconcileCertificates(ctx context.Context, tlsPoli

for _, gw := range gwDiffObj.GatewaysWithInvalidPolicyRef {
log.V(1).Info("reconcileCertificates: gateway with invalid policy ref", "key", gw.Key())
if err := r.deleteGatewayCertificates(ctx, gw.Gateway, tlsPolicy); err != nil {
if err := r.deleteCertificates(ctx, tlsPolicy); err != nil {
return err
}
}

// Reconcile Certificates for each gateway directly referred by the policy (existing and new)
for _, gw := range append(gwDiffObj.GatewaysWithValidPolicyRef, gwDiffObj.GatewaysMissingPolicyRef...) {
log.V(1).Info("reconcileCertificates: gateway with valid and missing policy ref", "key", gw.Key())
log.V(1).Info("reconcileCertificates: gateway with valid or missing policy ref", "key", gw.Key())
if err := r.reconcileGatewayCertificates(ctx, gw.Gateway, tlsPolicy); err != nil {
return err
}
Expand All @@ -50,7 +50,7 @@ func (r *TLSPolicyReconciler) reconcileGatewayCertificates(ctx context.Context,

expectedCerts := r.expectedCertificatesForGateway(ctx, gateway, tlsPolicy)

if err := r.deleteUnexpectedGatewayCertificates(ctx, expectedCerts, gateway, tlsPolicy); err != nil {
if err := r.deleteUnexpectedCertificates(ctx, expectedCerts, tlsPolicy); err != nil {
return err
}

Expand All @@ -65,14 +65,14 @@ func (r *TLSPolicyReconciler) reconcileGatewayCertificates(ctx context.Context,
return nil
}

func (r *TLSPolicyReconciler) deleteGatewayCertificates(ctx context.Context, gateway *gatewayv1beta1.Gateway, tlsPolicy *v1alpha1.TLSPolicy) error {
return r.deleteUnexpectedGatewayCertificates(ctx, []*certmanv1.Certificate{}, gateway, tlsPolicy)
func (r *TLSPolicyReconciler) deleteCertificates(ctx context.Context, tlsPolicy *v1alpha1.TLSPolicy) error {
return r.deleteUnexpectedCertificates(ctx, []*certmanv1.Certificate{}, tlsPolicy)
}

func (r *TLSPolicyReconciler) deleteUnexpectedGatewayCertificates(ctx context.Context, expectedCerts []*certmanv1.Certificate, gateway *gatewayv1beta1.Gateway, tlsPolicy *v1alpha1.TLSPolicy) error {
func (r *TLSPolicyReconciler) deleteUnexpectedCertificates(ctx context.Context, expectedCerts []*certmanv1.Certificate, tlsPolicy *v1alpha1.TLSPolicy) error {
log := crlog.FromContext(ctx)

listOptions := &client.ListOptions{LabelSelector: labels.SelectorFromSet(tlsCertificateLabels(client.ObjectKeyFromObject(gateway), client.ObjectKeyFromObject(tlsPolicy)))}
listOptions := &client.ListOptions{LabelSelector: labels.SelectorFromSet(policyTLSCertificateLabels(client.ObjectKeyFromObject(tlsPolicy)))}
certList := &certmanv1.CertificateList{}
if err := r.Client().List(ctx, certList, listOptions); err != nil {
return err
Expand Down Expand Up @@ -126,7 +126,7 @@ func (r *TLSPolicyReconciler) expectedCertificatesForGateway(ctx context.Context
}

func (r *TLSPolicyReconciler) buildCertManagerCertificate(gateway *gatewayv1beta1.Gateway, tlsPolicy *v1alpha1.TLSPolicy, secretRef corev1.ObjectReference, hosts []string) *certmanv1.Certificate {
tlsCertLabels := tlsCertificateLabels(client.ObjectKeyFromObject(gateway), client.ObjectKeyFromObject(tlsPolicy))
tlsCertLabels := commonTLSCertificateLabels(client.ObjectKeyFromObject(gateway), client.ObjectKeyFromObject(tlsPolicy))

crt := &certmanv1.Certificate{
ObjectMeta: metav1.ObjectMeta{
Expand All @@ -148,12 +148,28 @@ func (r *TLSPolicyReconciler) buildCertManagerCertificate(gateway *gatewayv1beta
return crt
}

func tlsCertificateLabels(gwKey, apKey client.ObjectKey) map[string]string {
func commonTLSCertificateLabels(gwKey, apKey client.ObjectKey) map[string]string {
common := map[string]string{}
for k, v := range policyTLSCertificateLabels(apKey) {
common[k] = v
}
for k, v := range gatewayTLSCertificateLabels(gwKey) {
common[k] = v
}
return common
}

func policyTLSCertificateLabels(apKey client.ObjectKey) map[string]string {
return map[string]string{
TLSPolicyBackRefAnnotation: apKey.Name,
fmt.Sprintf("%s-namespace", TLSPolicyBackRefAnnotation): apKey.Namespace,
"gateway-namespace": gwKey.Namespace,
"gateway": gwKey.Name,
}
}

func gatewayTLSCertificateLabels(gwKey client.ObjectKey) map[string]string {
return map[string]string{
"gateway-namespace": gwKey.Namespace,
"gateway": gwKey.Name,
}
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/controllers/tlspolicy/tlspolicy_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ func (r *TLSPolicyReconciler) deleteResources(ctx context.Context, tlsPolicy *v1
return err
}

if err := r.reconcileCertificates(ctx, tlsPolicy, gatewayDiffObj); err != nil {
if err := r.deleteCertificates(ctx, tlsPolicy); err != nil {
return err
}

Expand Down
32 changes: 31 additions & 1 deletion test/integration/dnspolicy_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package integration

import (
"encoding/json"
"errors"
"fmt"
"time"

Expand Down Expand Up @@ -340,7 +341,8 @@ var _ = Describe("DNSPolicy", Ordered, func() {
Expect(err).ToNot(HaveOccurred())

for _, record := range dnsRecordList.Items {
Expect(k8sClient.Delete(ctx, &record)).ToNot(HaveOccurred())
err := k8sClient.Delete(ctx, &record)
Expect(client.IgnoreNotFound(err)).ToNot(HaveOccurred())
}
})

Expand Down Expand Up @@ -673,6 +675,34 @@ var _ = Describe("DNSPolicy", Ordered, func() {
return nil
}, time.Second*5, time.Second).Should(BeNil())
})

It("should remove dns record reference on policy deletion even if gateway is removed", func() {
createdDNSRecord := &v1alpha1.DNSRecord{}
Eventually(func() error { // DNS record exists
if err := k8sClient.Get(ctx, client.ObjectKey{Name: dnsRecordName, Namespace: testNamespace}, createdDNSRecord); err != nil {
return err
}
return nil
}, TestTimeoutMedium, TestRetryIntervalMedium).Should(BeNil())

err := k8sClient.Delete(ctx, gateway)
Expect(client.IgnoreNotFound(err)).ToNot(HaveOccurred())

dnsPolicy = testBuildDNSPolicyWithHealthCheck("test-dns-policy", TestPlacedGatewayName, testNamespace, nil)
err = k8sClient.Delete(ctx, dnsPolicy)
Expect(client.IgnoreNotFound(err)).ToNot(HaveOccurred())

Eventually(func() error { // DNS record removed
if err := k8sClient.Get(ctx, client.ObjectKey{Name: dnsRecordName, Namespace: testNamespace}, createdDNSRecord); err != nil {
if k8serrors.IsNotFound(err) {
return nil
}
return err
}
return errors.New("found dnsrecord when it should be deleted")
}, TestTimeoutMedium, TestRetryIntervalMedium).Should(BeNil())
})

})

Context("geo dnspolicy", func() {
Expand Down
19 changes: 13 additions & 6 deletions test/integration/tlspolicy_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,23 +56,27 @@ var _ = Describe("TLSPolicy", Ordered, func() {
gatewayList := &gatewayv1beta1.GatewayList{}
Expect(k8sClient.List(ctx, gatewayList)).To(BeNil())
for _, gw := range gatewayList.Items {
k8sClient.Delete(ctx, &gw)
err := k8sClient.Delete(ctx, &gw)
Expect(client.IgnoreNotFound(err)).ToNot(HaveOccurred())
}
policyList := v1alpha1.TLSPolicyList{}
Expect(k8sClient.List(ctx, &policyList)).To(BeNil())
for _, policy := range policyList.Items {
k8sClient.Delete(ctx, &policy)
err := k8sClient.Delete(ctx, &policy)
Expect(client.IgnoreNotFound(err)).ToNot(HaveOccurred())
}
issuerList := certmanv1.IssuerList{}
Expect(k8sClient.List(ctx, &issuerList)).To(BeNil())
for _, issuer := range issuerList.Items {
k8sClient.Delete(ctx, &issuer)
err := k8sClient.Delete(ctx, &issuer)
Expect(client.IgnoreNotFound(err)).ToNot(HaveOccurred())
}
})

AfterAll(func() {
err := k8sClient.Delete(ctx, gatewayClass)
Expect(err).ToNot(HaveOccurred())
Expect(client.IgnoreNotFound(err)).ToNot(HaveOccurred())

})

Context("invalid target", func() {
Expand Down Expand Up @@ -522,7 +526,7 @@ var _ = Describe("TLSPolicy", Ordered, func() {
return nil
}, time.Second*120, time.Second).Should(BeNil())
})
It("should delete all tls certificates when tls policy is removed", func() {
It("should delete all tls certificates when tls policy is removed even if gateway is already removed", func() {
//confirm all expected certificates are present
Eventually(func() error {
certificateList := &certmanv1.CertificateList{}
Expand All @@ -533,8 +537,11 @@ var _ = Describe("TLSPolicy", Ordered, func() {
return nil
}, time.Second*10, time.Second).Should(BeNil())

// delete the gateway
Expect(client.IgnoreNotFound(k8sClient.Delete(ctx, gateway))).ToNot(HaveOccurred())

//delete the tls policy
Expect(k8sClient.Delete(ctx, tlsPolicy)).To(BeNil())
Expect(client.IgnoreNotFound(k8sClient.Delete(ctx, tlsPolicy))).ToNot(HaveOccurred())

//confirm all certificates have been deleted
Eventually(func() error {
Expand Down

0 comments on commit 2e1971e

Please sign in to comment.