From 95c852dcaf510ee21e5188358899e47916832110 Mon Sep 17 00:00:00 2001 From: Aswin Surayanarayanan Date: Tue, 8 Jun 2021 11:51:10 +0530 Subject: [PATCH 1/2] Add support for SRV record for headless services Fixes: #526 Signed-off-by: Aswin Surayanarayanan --- pkg/endpointslice/controller.go | 2 +- pkg/endpointslice/map.go | 57 +++++++++++++----- pkg/endpointslice/map_test.go | 10 +++- pkg/serviceimport/map.go | 5 +- plugin/lighthouse/handler.go | 21 ++++--- plugin/lighthouse/handler_test.go | 99 +++++++++++++++++++++++++++---- plugin/lighthouse/record.go | 79 +++++++++++++----------- 7 files changed, 199 insertions(+), 74 deletions(-) diff --git a/pkg/endpointslice/controller.go b/pkg/endpointslice/controller.go index 4fb1a1474..6809449fc 100644 --- a/pkg/endpointslice/controller.go +++ b/pkg/endpointslice/controller.go @@ -138,7 +138,7 @@ func (c *Controller) IsHealthy(name, namespace, clusterID string) bool { if endpointInfo != nil && endpointInfo.clusterInfo != nil { info := endpointInfo.clusterInfo[clusterID] if info != nil { - return len(info.ipList) > 0 + return len(info.recordList) > 0 } } diff --git a/pkg/endpointslice/map.go b/pkg/endpointslice/map.go index 3fab8c912..cfb40a594 100644 --- a/pkg/endpointslice/map.go +++ b/pkg/endpointslice/map.go @@ -22,8 +22,10 @@ import ( "github.com/submariner-io/admiral/pkg/log" "github.com/submariner-io/lighthouse/pkg/constants" + "github.com/submariner-io/lighthouse/pkg/serviceimport" discovery "k8s.io/api/discovery/v1beta1" "k8s.io/klog" + mcsv1a1 "sigs.k8s.io/mcs-api/pkg/apis/v1alpha1" ) type endpointInfo struct { @@ -32,8 +34,8 @@ type endpointInfo struct { } type clusterInfo struct { - hostIPs map[string][]string - ipList []string + hostRecords map[string][]serviceimport.DNSRecord + recordList []serviceimport.DNSRecord } type Map struct { @@ -41,7 +43,7 @@ type Map struct { sync.RWMutex } -func (m *Map) GetIPs(hostname, cluster, namespace, name string, checkCluster func(string) bool) ([]string, bool) { +func (m *Map) GetIPs(hostname, cluster, namespace, name string, checkCluster func(string) bool) ([]serviceimport.DNSRecord, bool) { key := keyFunc(name, namespace) clusterInfos := func() map[string]*clusterInfo { @@ -62,24 +64,24 @@ func (m *Map) GetIPs(hostname, cluster, namespace, name string, checkCluster fun switch { case cluster == "": - ips := make([]string, 0) + records := make([]serviceimport.DNSRecord, 0) for clusterID, info := range clusterInfos { if checkCluster == nil || checkCluster(clusterID) { - ips = append(ips, info.ipList...) + records = append(records, info.recordList...) } } - return ips, true + return records, true case clusterInfos[cluster] == nil: return nil, false case hostname == "": - return clusterInfos[cluster].ipList, true - case clusterInfos[cluster].hostIPs == nil: + return clusterInfos[cluster].recordList, true + case clusterInfos[cluster].hostRecords == nil: return nil, false default: - ips, ok := clusterInfos[cluster].hostIPs[hostname] - return ips, ok + records, ok := clusterInfos[cluster].hostRecords[hostname] + return records, ok } } @@ -115,16 +117,43 @@ func (m *Map) Put(es *discovery.EndpointSlice) { } epInfo.clusterInfo[cluster] = &clusterInfo{ - ipList: make([]string, 0), - hostIPs: make(map[string][]string), + recordList: make([]serviceimport.DNSRecord, 0), + hostRecords: make(map[string][]serviceimport.DNSRecord), + } + + mcsPorts := make([]mcsv1a1.ServicePort, len(es.Ports)) + + for i, port := range es.Ports { + mcsPort := mcsv1a1.ServicePort{ + Name: *port.Name, + Protocol: *port.Protocol, + AppProtocol: port.AppProtocol, + Port: *port.Port, + } + mcsPorts[i] = mcsPort } for _, endpoint := range es.Endpoints { + var records []serviceimport.DNSRecord + + for _, address := range endpoint.Addresses { + record := serviceimport.DNSRecord{ + IP: address, + Ports: mcsPorts, + } + + if endpoint.Hostname != nil { + record.HostName = *endpoint.Hostname + } + + records = append(records, record) + } + if endpoint.Hostname != nil { - epInfo.clusterInfo[cluster].hostIPs[*endpoint.Hostname] = endpoint.Addresses + epInfo.clusterInfo[cluster].hostRecords[*endpoint.Hostname] = records } - epInfo.clusterInfo[cluster].ipList = append(epInfo.clusterInfo[cluster].ipList, endpoint.Addresses...) + epInfo.clusterInfo[cluster].recordList = append(epInfo.clusterInfo[cluster].recordList, records...) } klog.V(log.DEBUG).Infof("Adding clusterInfo %#v for EndpointSlice %q in %q", epInfo.clusterInfo[cluster], es.Name, cluster) diff --git a/pkg/endpointslice/map_test.go b/pkg/endpointslice/map_test.go index f1a793c07..8a3ca06c9 100644 --- a/pkg/endpointslice/map_test.go +++ b/pkg/endpointslice/map_test.go @@ -20,6 +20,8 @@ package endpointslice_test import ( "sort" + "github.com/submariner-io/lighthouse/pkg/serviceimport" + . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" "github.com/submariner-io/lighthouse/pkg/endpointslice" @@ -55,7 +57,7 @@ var _ = Describe("EndpointSlice Map", func() { return clusterStatusMap[id] } - getIPs := func(hostname, cluster, ns, name string) []string { + getIPs := func(hostname, cluster, ns, name string) []serviceimport.DNSRecord { ips, found := endpointSliceMap.GetIPs(hostname, cluster, ns, name, checkCluster) Expect(found).To(BeTrue()) return ips @@ -64,7 +66,11 @@ var _ = Describe("EndpointSlice Map", func() { expectIPs := func(hostname, cluster, ns, name string, expIPs []string) { sort.Strings(expIPs) for i := 0; i < 5; i++ { - ips := getIPs(hostname, cluster, namespace1, service1) + var ips []string + records := getIPs(hostname, cluster, namespace1, service1) + for _, record := range records { + ips = append(ips, record.IP) + } sort.Strings(ips) Expect(ips).To(Equal(expIPs)) } diff --git a/pkg/serviceimport/map.go b/pkg/serviceimport/map.go index 73adf501c..2584b9a6c 100644 --- a/pkg/serviceimport/map.go +++ b/pkg/serviceimport/map.go @@ -26,8 +26,9 @@ import ( ) type DNSRecord struct { - IP string - Ports []mcsv1a1.ServicePort + IP string + Ports []mcsv1a1.ServicePort + HostName string } type clusterInfo struct { diff --git a/plugin/lighthouse/handler.go b/plugin/lighthouse/handler.go index 8b0cc5948..83109dc5d 100644 --- a/plugin/lighthouse/handler.go +++ b/plugin/lighthouse/handler.go @@ -68,24 +68,27 @@ func (lh *Lighthouse) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns func (lh *Lighthouse) getDNSRecord(zone string, state request.Request, ctx context.Context, w dns.ResponseWriter, r *dns.Msg, pReq recordRequest) (int, error) { + var isHeadless bool var ( - ips []string - found bool - record *serviceimport.DNSRecord + dnsRecords []serviceimport.DNSRecord + found bool + record *serviceimport.DNSRecord ) record, found = lh.getClusterIPForSvc(pReq) if !found { - ips, found = lh.endpointSlices.GetIPs(pReq.hostname, pReq.cluster, pReq.namespace, pReq.service, lh.clusterStatus.IsConnected) + dnsRecords, found = lh.endpointSlices.GetIPs(pReq.hostname, pReq.cluster, pReq.namespace, pReq.service, lh.clusterStatus.IsConnected) if !found { log.Debugf("No record found for %q", state.QName()) return lh.nextOrFailure(state.Name(), ctx, w, r, dns.RcodeNameError, "record not found") } + + isHeadless = true } else if record != nil && record.IP != "" { - ips = []string{record.IP} + dnsRecords = append(dnsRecords, *record) } - if len(ips) == 0 { + if len(dnsRecords) == 0 { log.Debugf("Couldn't find a connected cluster or valid IPs for %q", state.QName()) return lh.emptyResponse(state) } @@ -95,12 +98,12 @@ func (lh *Lighthouse) getDNSRecord(zone string, state request.Request, ctx conte return lh.emptyResponse(state) } - var records []dns.RR + records := make([]dns.RR, 0) if state.QType() == dns.TypeA { - records = lh.createARecords(ips, state) + records = lh.createARecords(dnsRecords, state) } else if state.QType() == dns.TypeSRV { - records = lh.createSRVRecords(record, state, pReq, zone) + records = lh.createSRVRecords(dnsRecords, state, pReq, zone, isHeadless) } if len(records) == 0 { diff --git a/plugin/lighthouse/handler_test.go b/plugin/lighthouse/handler_test.go index 696e4d544..cabc2e2f7 100644 --- a/plugin/lighthouse/handler_test.go +++ b/plugin/lighthouse/handler_test.go @@ -54,6 +54,8 @@ const ( portNumber1 = int32(8080) protcol2 = v1.ProtocolUDP portNumber2 = int32(53) + hostName1 = "hostName1" + hostName2 = "hostName2" ) var _ = Describe("Lighthouse DNS plugin Handler", func() { @@ -630,7 +632,7 @@ func testHeadlessService() { JustBeforeEach(func() { lh.serviceImports.Put(newServiceImport(namespace1, service1, clusterID, "", portName1, portNumber1, protcol1, mcsv1a1.Headless)) - lh.endpointSlices.Put(newEndpointSlice(namespace1, service1, clusterID, []string{})) + lh.endpointSlices.Put(newEndpointSlice(namespace1, service1, clusterID, portName1, []string{}, []string{}, portNumber1, protcol1)) }) It("should succeed and return empty response (NODATA)", func() { executeTestCase(lh, rec, test.Case{ @@ -640,13 +642,21 @@ func testHeadlessService() { Answer: []dns.RR{}, }) }) + It("should succeed and return empty response (NODATA)", func() { + executeTestCase(lh, rec, test.Case{ + Qname: service1 + "." + namespace1 + ".svc.clusterset.local.", + Qtype: dns.TypeSRV, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{}, + }) + }) }) - When("headless service has one IP", func() { JustBeforeEach(func() { lh.serviceImports.Put(newServiceImport(namespace1, service1, clusterID, "", portName1, portNumber1, protcol1, mcsv1a1.Headless)) - lh.endpointSlices.Put(newEndpointSlice(namespace1, service1, clusterID, []string{endpointIP})) + lh.endpointSlices.Put(newEndpointSlice(namespace1, service1, clusterID, portName1, []string{hostName1}, []string{endpointIP}, + portNumber1, protcol1)) }) It("should succeed and write an A record response", func() { executeTestCase(lh, rec, test.Case{ @@ -658,13 +668,38 @@ func testHeadlessService() { }, }) }) + It("should succeed and write an SRV record response", func() { + executeTestCase(lh, rec, test.Case{ + Qname: service1 + "." + namespace1 + ".svc.clusterset.local.", + Qtype: dns.TypeSRV, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.SRV(service1 + "." + namespace1 + ".svc.clusterset.local. 5 IN SRV 0 50 " + + strconv.Itoa(int(portNumber1)) + " " + hostName1 + "." + service1 + "." + namespace1 + ".svc.clusterset.local."), + }, + }) + }) + It("should succeed and write an SRV record response for query with cluster name", func() { + executeTestCase(lh, rec, test.Case{ + Qname: clusterID + "." + service1 + "." + namespace1 + ".svc.clusterset.local.", + Qtype: dns.TypeSRV, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.SRV(clusterID + "." + service1 + "." + namespace1 + ".svc.clusterset.local. 5 IN SRV 0 50 " + + strconv.Itoa(int(portNumber1)) + " " + hostName1 + "." + clusterID + "." + service1 + "." + namespace1 + + ".svc.clusterset.local."), + }, + }) + }) }) When("headless service has two IPs", func() { JustBeforeEach(func() { lh.serviceImports.Put(newServiceImport(namespace1, service1, clusterID, "", portName1, portNumber1, protcol1, mcsv1a1.Headless)) - lh.endpointSlices.Put(newEndpointSlice(namespace1, service1, clusterID, []string{endpointIP, endpointIP2})) + lh.endpointSlices.Put(newEndpointSlice(namespace1, service1, clusterID, portName1, []string{hostName1, hostName2}, + []string{endpointIP, endpointIP2}, + portNumber1, protcol1)) }) It("should succeed and write two A records as response", func() { executeTestCase(lh, rec, test.Case{ @@ -677,6 +712,34 @@ func testHeadlessService() { }, }) }) + It("should succeed and write an SRV record response", func() { + executeTestCase(lh, rec, test.Case{ + Qname: service1 + "." + namespace1 + ".svc.clusterset.local.", + Qtype: dns.TypeSRV, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.SRV(service1 + "." + namespace1 + ".svc.clusterset.local. 5 IN SRV 0 50 " + + strconv.Itoa(int(portNumber1)) + " " + hostName1 + "." + service1 + "." + namespace1 + ".svc.clusterset.local."), + test.SRV(service1 + "." + namespace1 + ".svc.clusterset.local. 5 IN SRV 0 50 " + + strconv.Itoa(int(portNumber1)) + " " + hostName2 + "." + service1 + "." + namespace1 + ".svc.clusterset.local."), + }, + }) + }) + It("should succeed and write an SRV record response when port and protocol is queried", func() { + executeTestCase(lh, rec, test.Case{ + Qname: portName1 + "." + string(protcol1) + "." + service1 + "." + namespace1 + ".svc.clusterset.local.", + Qtype: dns.TypeSRV, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.SRV(portName1 + "." + string(protcol1) + "." + service1 + "." + namespace1 + ".svc.clusterset.local." + + " 5 IN SRV 0 50 " + strconv.Itoa(int(portNumber1)) + " " + hostName1 + "." + service1 + "." + + namespace1 + ".svc.clusterset.local."), + test.SRV(portName1 + "." + string(protcol1) + "." + service1 + "." + namespace1 + ".svc.clusterset.local." + + " 5 IN SRV 0 50 " + strconv.Itoa(int(portNumber1)) + " " + hostName2 + "." + service1 + "." + + namespace1 + ".svc.clusterset.local."), + }, + }) + }) }) When("headless service is present in two clusters", func() { @@ -685,8 +748,10 @@ func testHeadlessService() { portNumber1, protcol1, mcsv1a1.Headless)) lh.serviceImports.Put(newServiceImport(namespace1, service1, clusterID2, "", portName1, portNumber1, protcol1, mcsv1a1.Headless)) - lh.endpointSlices.Put(newEndpointSlice(namespace1, service1, clusterID, []string{endpointIP})) - lh.endpointSlices.Put(newEndpointSlice(namespace1, service1, clusterID2, []string{endpointIP2})) + lh.endpointSlices.Put(newEndpointSlice(namespace1, service1, clusterID, portName1, []string{hostName1}, []string{endpointIP}, + portNumber1, protcol1)) + lh.endpointSlices.Put(newEndpointSlice(namespace1, service1, clusterID2, portName1, []string{hostName2}, []string{endpointIP2}, + portNumber1, protcol1)) mockCs.clusterStatusMap[clusterID2] = true }) When("no cluster is requested", func() { @@ -966,7 +1031,7 @@ func setupServiceImportMap() *serviceimport.Map { func setupEndpointSliceMap() *endpointslice.Map { esMap := endpointslice.NewMap() - esMap.Put(newEndpointSlice(namespace1, service1, clusterID, []string{endpointIP})) + esMap.Put(newEndpointSlice(namespace1, service1, clusterID, portName1, []string{hostName1}, []string{endpointIP}, portNumber1, protcol1)) return esMap } @@ -1006,7 +1071,18 @@ func newServiceImport(namespace, name, clusterID, serviceIP, portName string, } } -func newEndpointSlice(namespace, name, clusterID string, endpointIPs []string) *discovery.EndpointSlice { +func newEndpointSlice(namespace, name, clusterID, portName string, hostName, endpointIPs []string, portNumber int32, + protocol v1.Protocol) *discovery.EndpointSlice { + endpoints := make([]discovery.Endpoint, len(endpointIPs)) + + for i := range endpointIPs { + endpoint := discovery.Endpoint{ + Addresses: []string{endpointIPs[i]}, + Hostname: &hostName[i], + } + endpoints[i] = endpoint + } + return &discovery.EndpointSlice{ ObjectMeta: metav1.ObjectMeta{ Name: name, @@ -1020,9 +1096,12 @@ func newEndpointSlice(namespace, name, clusterID string, endpointIPs []string) * }, }, AddressType: discovery.AddressTypeIPv4, - Endpoints: []discovery.Endpoint{ + Endpoints: endpoints, + Ports: []discovery.EndpointPort{ { - Addresses: endpointIPs, + Name: &portName, + Protocol: &protocol, + Port: &portNumber, }, }, } diff --git a/plugin/lighthouse/record.go b/plugin/lighthouse/record.go index 64ba56e0f..3f95b37f9 100644 --- a/plugin/lighthouse/record.go +++ b/plugin/lighthouse/record.go @@ -27,57 +27,64 @@ import ( "sigs.k8s.io/mcs-api/pkg/apis/v1alpha1" ) -func (lh *Lighthouse) createARecords(ips []string, state request.Request) []dns.RR { +func (lh *Lighthouse) createARecords(dnsrecords []serviceimport.DNSRecord, state request.Request) []dns.RR { records := make([]dns.RR, 0) - for _, ip := range ips { - record := &dns.A{Hdr: dns.RR_Header{Name: state.QName(), Rrtype: dns.TypeA, Class: state.QClass(), Ttl: lh.ttl}, A: net.ParseIP(ip).To4()} - log.Debugf("rr is %v", record) - records = append(records, record) + for _, record := range dnsrecords { + dnsRecord := &dns.A{Hdr: dns.RR_Header{Name: state.QName(), Rrtype: dns.TypeA, Class: state.QClass(), + Ttl: lh.ttl}, A: net.ParseIP(record.IP).To4()} + records = append(records, dnsRecord) } return records } -func (lh *Lighthouse) createSRVRecords(record *serviceimport.DNSRecord, state request.Request, pReq recordRequest, zone string) []dns.RR { - var reqPorts []v1alpha1.ServicePort - - if pReq.port == "" { - reqPorts = record.Ports - } else { - log.Debugf("Requested port %q, protocol %q for SRV", pReq.port, pReq.protocol) - for _, port := range record.Ports { - name := strings.ToLower(port.Name) - protocol := strings.ToLower(string(port.Protocol)) - - log.Debugf("Checking port %q, protocol %q", name, protocol) - if name == pReq.port && protocol == pReq.protocol { - reqPorts = append(reqPorts, port) +func (lh *Lighthouse) createSRVRecords(dnsrecords []serviceimport.DNSRecord, state request.Request, pReq recordRequest, zone string, + isHeadless bool) []dns.RR { + var records []dns.RR + + for _, dnsRecord := range dnsrecords { + var reqPorts []v1alpha1.ServicePort + + if pReq.port == "" { + reqPorts = dnsRecord.Ports + } else { + log.Debugf("Requested port %q, protocol %q for SRV", pReq.port, pReq.protocol) + for _, port := range dnsRecord.Ports { + name := strings.ToLower(port.Name) + protocol := strings.ToLower(string(port.Protocol)) + + log.Debugf("Checking port %q, protocol %q", name, protocol) + if name == pReq.port && protocol == pReq.protocol { + reqPorts = append(reqPorts, port) + } } } - } - if len(reqPorts) == 0 { - return nil - } + if len(reqPorts) == 0 { + return nil + } - target := pReq.service + "." + pReq.namespace + ".svc." + zone + target := pReq.service + "." + pReq.namespace + ".svc." + zone - if pReq.cluster != "" { - target = pReq.cluster + "." + target - } + if pReq.cluster != "" { + target = pReq.cluster + "." + target + } - records := make([]dns.RR, len(reqPorts)) + if isHeadless { + target = dnsRecord.HostName + "." + target + } - for index, port := range reqPorts { - record := &dns.SRV{ - Hdr: dns.RR_Header{Name: state.QName(), Rrtype: dns.TypeSRV, Class: state.QClass(), Ttl: lh.ttl}, - Priority: 0, - Weight: 50, - Port: uint16(port.Port), - Target: target, + for _, port := range reqPorts { + record := &dns.SRV{ + Hdr: dns.RR_Header{Name: state.QName(), Rrtype: dns.TypeSRV, Class: state.QClass(), Ttl: lh.ttl}, + Priority: 0, + Weight: 50, + Port: uint16(port.Port), + Target: target, + } + records = append(records, record) } - records[index] = record } return records From a834072acee55cdb6c2316521f688198179361c6 Mon Sep 17 00:00:00 2001 From: Aswin Surayanarayanan Date: Thu, 10 Jun 2021 18:39:00 +0530 Subject: [PATCH 2/2] Apply review comments Signed-off-by: Aswin Surayanarayanan --- pkg/endpointslice/map.go | 2 +- pkg/endpointslice/map_test.go | 6 +++--- plugin/lighthouse/handler.go | 3 ++- plugin/lighthouse/handler_test.go | 27 +++++++++++++++++++++++++++ plugin/lighthouse/parse.go | 17 +++++++++++++---- 5 files changed, 46 insertions(+), 9 deletions(-) diff --git a/pkg/endpointslice/map.go b/pkg/endpointslice/map.go index cfb40a594..6acf3af53 100644 --- a/pkg/endpointslice/map.go +++ b/pkg/endpointslice/map.go @@ -43,7 +43,7 @@ type Map struct { sync.RWMutex } -func (m *Map) GetIPs(hostname, cluster, namespace, name string, checkCluster func(string) bool) ([]serviceimport.DNSRecord, bool) { +func (m *Map) GetDNSRecords(hostname, cluster, namespace, name string, checkCluster func(string) bool) ([]serviceimport.DNSRecord, bool) { key := keyFunc(name, namespace) clusterInfos := func() map[string]*clusterInfo { diff --git a/pkg/endpointslice/map_test.go b/pkg/endpointslice/map_test.go index 8a3ca06c9..d209cd858 100644 --- a/pkg/endpointslice/map_test.go +++ b/pkg/endpointslice/map_test.go @@ -57,8 +57,8 @@ var _ = Describe("EndpointSlice Map", func() { return clusterStatusMap[id] } - getIPs := func(hostname, cluster, ns, name string) []serviceimport.DNSRecord { - ips, found := endpointSliceMap.GetIPs(hostname, cluster, ns, name, checkCluster) + getRecords := func(hostname, cluster, ns, name string) []serviceimport.DNSRecord { + ips, found := endpointSliceMap.GetDNSRecords(hostname, cluster, ns, name, checkCluster) Expect(found).To(BeTrue()) return ips } @@ -67,7 +67,7 @@ var _ = Describe("EndpointSlice Map", func() { sort.Strings(expIPs) for i := 0; i < 5; i++ { var ips []string - records := getIPs(hostname, cluster, namespace1, service1) + records := getRecords(hostname, cluster, namespace1, service1) for _, record := range records { ips = append(ips, record.IP) } diff --git a/plugin/lighthouse/handler.go b/plugin/lighthouse/handler.go index 83109dc5d..7c7837d83 100644 --- a/plugin/lighthouse/handler.go +++ b/plugin/lighthouse/handler.go @@ -77,7 +77,8 @@ func (lh *Lighthouse) getDNSRecord(zone string, state request.Request, ctx conte record, found = lh.getClusterIPForSvc(pReq) if !found { - dnsRecords, found = lh.endpointSlices.GetIPs(pReq.hostname, pReq.cluster, pReq.namespace, pReq.service, lh.clusterStatus.IsConnected) + dnsRecords, found = lh.endpointSlices.GetDNSRecords(pReq.hostname, pReq.cluster, pReq.namespace, + pReq.service, lh.clusterStatus.IsConnected) if !found { log.Debugf("No record found for %q", state.QName()) return lh.nextOrFailure(state.Name(), ctx, w, r, dns.RcodeNameError, "record not found") diff --git a/plugin/lighthouse/handler_test.go b/plugin/lighthouse/handler_test.go index cabc2e2f7..acd7e5161 100644 --- a/plugin/lighthouse/handler_test.go +++ b/plugin/lighthouse/handler_test.go @@ -740,6 +740,21 @@ func testHeadlessService() { }, }) }) + It("should succeed and write an SRV record response when port and protocol is queried with underscore prefix", func() { + executeTestCase(lh, rec, test.Case{ + Qname: "_" + portName1 + "." + "_" + string(protcol1) + "." + service1 + "." + namespace1 + ".svc.clusterset.local.", + Qtype: dns.TypeSRV, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.SRV("_" + portName1 + "." + "_" + string(protcol1) + "." + service1 + "." + namespace1 + ".svc.clusterset.local." + + " 5 IN SRV 0 50 " + strconv.Itoa(int(portNumber1)) + " " + hostName1 + "." + service1 + "." + + namespace1 + ".svc.clusterset.local."), + test.SRV("_" + portName1 + "." + "_" + string(protcol1) + "." + service1 + "." + namespace1 + ".svc.clusterset.local." + + " 5 IN SRV 0 50 " + strconv.Itoa(int(portNumber1)) + " " + hostName2 + "." + service1 + "." + + namespace1 + ".svc.clusterset.local."), + }, + }) + }) }) When("headless service is present in two clusters", func() { @@ -1006,6 +1021,18 @@ func testSRVMultiplePorts() { }, }) }) + It("with HTTP portname should return TCP port with underscore prefix", func() { + executeTestCase(lh, rec, test.Case{ + Qname: "_" + portName1 + "." + "_" + string(protcol1) + "." + service1 + "." + namespace1 + ".svc.clusterset.local.", + Qtype: dns.TypeSRV, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.SRV("_" + portName1 + "." + "_" + string(protcol1) + "." + service1 + "." + namespace1 + + ".svc.clusterset.local. 5 IN SRV 0 50 " + strconv.Itoa(int(portNumber1)) + " " + service1 + "." + + namespace1 + ".svc.clusterset.local."), + }, + }) + }) }) } diff --git a/plugin/lighthouse/parse.go b/plugin/lighthouse/parse.go index 94bda31a3..130d1e96f 100644 --- a/plugin/lighthouse/parse.go +++ b/plugin/lighthouse/parse.go @@ -128,13 +128,13 @@ func parseSegments(segs []string, count int, r recordRequest, state request.Requ case 0: // cluster only r.cluster = segs[count] case 1: // endpoint only - r.protocol = segs[count] - r.port = segs[count-1] + r.protocol = stripUnderscore(segs[count]) + r.port = stripUnderscore(segs[count-1]) case 2: // service and port r.cluster = segs[count] - r.protocol = segs[count-1] - r.port = segs[count-2] + r.protocol = stripUnderscore(segs[count-1]) + r.port = stripUnderscore(segs[count-2]) default: // too long return r, errInvalidRequest } @@ -142,3 +142,12 @@ func parseSegments(segs []string, count int, r recordRequest, state request.Requ return r, nil } + +// stripUnderscore removes a prefixed underscore from s. +func stripUnderscore(s string) string { + if s[0] != '_' { + return s + } + + return s[1:] +}