Skip to content

Commit

Permalink
Add support for SRV record for headless services
Browse files Browse the repository at this point in the history
Fixes: #526
Signed-off-by: Aswin Surayanarayanan <[email protected]>
  • Loading branch information
aswinsuryan committed Jun 9, 2021
1 parent 9a1a924 commit 95c852d
Show file tree
Hide file tree
Showing 7 changed files with 199 additions and 74 deletions.
2 changes: 1 addition & 1 deletion pkg/endpointslice/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ func (c *Controller) IsHealthy(name, namespace, clusterID string) bool {
if endpointInfo != nil && endpointInfo.clusterInfo != nil {
info := endpointInfo.clusterInfo[clusterID]
if info != nil {
return len(info.ipList) > 0
return len(info.recordList) > 0
}
}

Expand Down
57 changes: 43 additions & 14 deletions pkg/endpointslice/map.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ import (

"github.com/submariner-io/admiral/pkg/log"
"github.com/submariner-io/lighthouse/pkg/constants"
"github.com/submariner-io/lighthouse/pkg/serviceimport"
discovery "k8s.io/api/discovery/v1beta1"
"k8s.io/klog"
mcsv1a1 "sigs.k8s.io/mcs-api/pkg/apis/v1alpha1"
)

type endpointInfo struct {
Expand All @@ -32,16 +34,16 @@ type endpointInfo struct {
}

type clusterInfo struct {
hostIPs map[string][]string
ipList []string
hostRecords map[string][]serviceimport.DNSRecord
recordList []serviceimport.DNSRecord
}

type Map struct {
epMap map[string]*endpointInfo
sync.RWMutex
}

func (m *Map) GetIPs(hostname, cluster, namespace, name string, checkCluster func(string) bool) ([]string, bool) {
func (m *Map) GetIPs(hostname, cluster, namespace, name string, checkCluster func(string) bool) ([]serviceimport.DNSRecord, bool) {
key := keyFunc(name, namespace)

clusterInfos := func() map[string]*clusterInfo {
Expand All @@ -62,24 +64,24 @@ func (m *Map) GetIPs(hostname, cluster, namespace, name string, checkCluster fun

switch {
case cluster == "":
ips := make([]string, 0)
records := make([]serviceimport.DNSRecord, 0)

for clusterID, info := range clusterInfos {
if checkCluster == nil || checkCluster(clusterID) {
ips = append(ips, info.ipList...)
records = append(records, info.recordList...)
}
}

return ips, true
return records, true
case clusterInfos[cluster] == nil:
return nil, false
case hostname == "":
return clusterInfos[cluster].ipList, true
case clusterInfos[cluster].hostIPs == nil:
return clusterInfos[cluster].recordList, true
case clusterInfos[cluster].hostRecords == nil:
return nil, false
default:
ips, ok := clusterInfos[cluster].hostIPs[hostname]
return ips, ok
records, ok := clusterInfos[cluster].hostRecords[hostname]
return records, ok
}
}

Expand Down Expand Up @@ -115,16 +117,43 @@ func (m *Map) Put(es *discovery.EndpointSlice) {
}

epInfo.clusterInfo[cluster] = &clusterInfo{
ipList: make([]string, 0),
hostIPs: make(map[string][]string),
recordList: make([]serviceimport.DNSRecord, 0),
hostRecords: make(map[string][]serviceimport.DNSRecord),
}

mcsPorts := make([]mcsv1a1.ServicePort, len(es.Ports))

for i, port := range es.Ports {
mcsPort := mcsv1a1.ServicePort{
Name: *port.Name,
Protocol: *port.Protocol,
AppProtocol: port.AppProtocol,
Port: *port.Port,
}
mcsPorts[i] = mcsPort
}

for _, endpoint := range es.Endpoints {
var records []serviceimport.DNSRecord

for _, address := range endpoint.Addresses {
record := serviceimport.DNSRecord{
IP: address,
Ports: mcsPorts,
}

if endpoint.Hostname != nil {
record.HostName = *endpoint.Hostname
}

records = append(records, record)
}

if endpoint.Hostname != nil {
epInfo.clusterInfo[cluster].hostIPs[*endpoint.Hostname] = endpoint.Addresses
epInfo.clusterInfo[cluster].hostRecords[*endpoint.Hostname] = records
}

epInfo.clusterInfo[cluster].ipList = append(epInfo.clusterInfo[cluster].ipList, endpoint.Addresses...)
epInfo.clusterInfo[cluster].recordList = append(epInfo.clusterInfo[cluster].recordList, records...)
}

klog.V(log.DEBUG).Infof("Adding clusterInfo %#v for EndpointSlice %q in %q", epInfo.clusterInfo[cluster], es.Name, cluster)
Expand Down
10 changes: 8 additions & 2 deletions pkg/endpointslice/map_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ package endpointslice_test
import (
"sort"

"github.com/submariner-io/lighthouse/pkg/serviceimport"

. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"github.com/submariner-io/lighthouse/pkg/endpointslice"
Expand Down Expand Up @@ -55,7 +57,7 @@ var _ = Describe("EndpointSlice Map", func() {
return clusterStatusMap[id]
}

getIPs := func(hostname, cluster, ns, name string) []string {
getIPs := func(hostname, cluster, ns, name string) []serviceimport.DNSRecord {
ips, found := endpointSliceMap.GetIPs(hostname, cluster, ns, name, checkCluster)
Expect(found).To(BeTrue())
return ips
Expand All @@ -64,7 +66,11 @@ var _ = Describe("EndpointSlice Map", func() {
expectIPs := func(hostname, cluster, ns, name string, expIPs []string) {
sort.Strings(expIPs)
for i := 0; i < 5; i++ {
ips := getIPs(hostname, cluster, namespace1, service1)
var ips []string
records := getIPs(hostname, cluster, namespace1, service1)
for _, record := range records {
ips = append(ips, record.IP)
}
sort.Strings(ips)
Expect(ips).To(Equal(expIPs))
}
Expand Down
5 changes: 3 additions & 2 deletions pkg/serviceimport/map.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ import (
)

type DNSRecord struct {
IP string
Ports []mcsv1a1.ServicePort
IP string
Ports []mcsv1a1.ServicePort
HostName string
}

type clusterInfo struct {
Expand Down
21 changes: 12 additions & 9 deletions plugin/lighthouse/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,24 +68,27 @@ func (lh *Lighthouse) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns

func (lh *Lighthouse) getDNSRecord(zone string, state request.Request, ctx context.Context, w dns.ResponseWriter,
r *dns.Msg, pReq recordRequest) (int, error) {
var isHeadless bool
var (
ips []string
found bool
record *serviceimport.DNSRecord
dnsRecords []serviceimport.DNSRecord
found bool
record *serviceimport.DNSRecord
)

record, found = lh.getClusterIPForSvc(pReq)
if !found {
ips, found = lh.endpointSlices.GetIPs(pReq.hostname, pReq.cluster, pReq.namespace, pReq.service, lh.clusterStatus.IsConnected)
dnsRecords, found = lh.endpointSlices.GetIPs(pReq.hostname, pReq.cluster, pReq.namespace, pReq.service, lh.clusterStatus.IsConnected)
if !found {
log.Debugf("No record found for %q", state.QName())
return lh.nextOrFailure(state.Name(), ctx, w, r, dns.RcodeNameError, "record not found")
}

isHeadless = true
} else if record != nil && record.IP != "" {
ips = []string{record.IP}
dnsRecords = append(dnsRecords, *record)
}

if len(ips) == 0 {
if len(dnsRecords) == 0 {
log.Debugf("Couldn't find a connected cluster or valid IPs for %q", state.QName())
return lh.emptyResponse(state)
}
Expand All @@ -95,12 +98,12 @@ func (lh *Lighthouse) getDNSRecord(zone string, state request.Request, ctx conte
return lh.emptyResponse(state)
}

var records []dns.RR
records := make([]dns.RR, 0)

if state.QType() == dns.TypeA {
records = lh.createARecords(ips, state)
records = lh.createARecords(dnsRecords, state)
} else if state.QType() == dns.TypeSRV {
records = lh.createSRVRecords(record, state, pReq, zone)
records = lh.createSRVRecords(dnsRecords, state, pReq, zone, isHeadless)
}

if len(records) == 0 {
Expand Down
99 changes: 89 additions & 10 deletions plugin/lighthouse/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ const (
portNumber1 = int32(8080)
protcol2 = v1.ProtocolUDP
portNumber2 = int32(53)
hostName1 = "hostName1"
hostName2 = "hostName2"
)

var _ = Describe("Lighthouse DNS plugin Handler", func() {
Expand Down Expand Up @@ -630,7 +632,7 @@ func testHeadlessService() {
JustBeforeEach(func() {
lh.serviceImports.Put(newServiceImport(namespace1, service1, clusterID, "", portName1,
portNumber1, protcol1, mcsv1a1.Headless))
lh.endpointSlices.Put(newEndpointSlice(namespace1, service1, clusterID, []string{}))
lh.endpointSlices.Put(newEndpointSlice(namespace1, service1, clusterID, portName1, []string{}, []string{}, portNumber1, protcol1))
})
It("should succeed and return empty response (NODATA)", func() {
executeTestCase(lh, rec, test.Case{
Expand All @@ -640,13 +642,21 @@ func testHeadlessService() {
Answer: []dns.RR{},
})
})
It("should succeed and return empty response (NODATA)", func() {
executeTestCase(lh, rec, test.Case{
Qname: service1 + "." + namespace1 + ".svc.clusterset.local.",
Qtype: dns.TypeSRV,
Rcode: dns.RcodeSuccess,
Answer: []dns.RR{},
})
})
})

When("headless service has one IP", func() {
JustBeforeEach(func() {
lh.serviceImports.Put(newServiceImport(namespace1, service1, clusterID, "", portName1,
portNumber1, protcol1, mcsv1a1.Headless))
lh.endpointSlices.Put(newEndpointSlice(namespace1, service1, clusterID, []string{endpointIP}))
lh.endpointSlices.Put(newEndpointSlice(namespace1, service1, clusterID, portName1, []string{hostName1}, []string{endpointIP},
portNumber1, protcol1))
})
It("should succeed and write an A record response", func() {
executeTestCase(lh, rec, test.Case{
Expand All @@ -658,13 +668,38 @@ func testHeadlessService() {
},
})
})
It("should succeed and write an SRV record response", func() {
executeTestCase(lh, rec, test.Case{
Qname: service1 + "." + namespace1 + ".svc.clusterset.local.",
Qtype: dns.TypeSRV,
Rcode: dns.RcodeSuccess,
Answer: []dns.RR{
test.SRV(service1 + "." + namespace1 + ".svc.clusterset.local. 5 IN SRV 0 50 " +
strconv.Itoa(int(portNumber1)) + " " + hostName1 + "." + service1 + "." + namespace1 + ".svc.clusterset.local."),
},
})
})
It("should succeed and write an SRV record response for query with cluster name", func() {
executeTestCase(lh, rec, test.Case{
Qname: clusterID + "." + service1 + "." + namespace1 + ".svc.clusterset.local.",
Qtype: dns.TypeSRV,
Rcode: dns.RcodeSuccess,
Answer: []dns.RR{
test.SRV(clusterID + "." + service1 + "." + namespace1 + ".svc.clusterset.local. 5 IN SRV 0 50 " +
strconv.Itoa(int(portNumber1)) + " " + hostName1 + "." + clusterID + "." + service1 + "." + namespace1 +
".svc.clusterset.local."),
},
})
})
})

When("headless service has two IPs", func() {
JustBeforeEach(func() {
lh.serviceImports.Put(newServiceImport(namespace1, service1, clusterID, "", portName1, portNumber1, protcol1,
mcsv1a1.Headless))
lh.endpointSlices.Put(newEndpointSlice(namespace1, service1, clusterID, []string{endpointIP, endpointIP2}))
lh.endpointSlices.Put(newEndpointSlice(namespace1, service1, clusterID, portName1, []string{hostName1, hostName2},
[]string{endpointIP, endpointIP2},
portNumber1, protcol1))
})
It("should succeed and write two A records as response", func() {
executeTestCase(lh, rec, test.Case{
Expand All @@ -677,6 +712,34 @@ func testHeadlessService() {
},
})
})
It("should succeed and write an SRV record response", func() {
executeTestCase(lh, rec, test.Case{
Qname: service1 + "." + namespace1 + ".svc.clusterset.local.",
Qtype: dns.TypeSRV,
Rcode: dns.RcodeSuccess,
Answer: []dns.RR{
test.SRV(service1 + "." + namespace1 + ".svc.clusterset.local. 5 IN SRV 0 50 " +
strconv.Itoa(int(portNumber1)) + " " + hostName1 + "." + service1 + "." + namespace1 + ".svc.clusterset.local."),
test.SRV(service1 + "." + namespace1 + ".svc.clusterset.local. 5 IN SRV 0 50 " +
strconv.Itoa(int(portNumber1)) + " " + hostName2 + "." + service1 + "." + namespace1 + ".svc.clusterset.local."),
},
})
})
It("should succeed and write an SRV record response when port and protocol is queried", func() {
executeTestCase(lh, rec, test.Case{
Qname: portName1 + "." + string(protcol1) + "." + service1 + "." + namespace1 + ".svc.clusterset.local.",
Qtype: dns.TypeSRV,
Rcode: dns.RcodeSuccess,
Answer: []dns.RR{
test.SRV(portName1 + "." + string(protcol1) + "." + service1 + "." + namespace1 + ".svc.clusterset.local." +
" 5 IN SRV 0 50 " + strconv.Itoa(int(portNumber1)) + " " + hostName1 + "." + service1 + "." +
namespace1 + ".svc.clusterset.local."),
test.SRV(portName1 + "." + string(protcol1) + "." + service1 + "." + namespace1 + ".svc.clusterset.local." +
" 5 IN SRV 0 50 " + strconv.Itoa(int(portNumber1)) + " " + hostName2 + "." + service1 + "." +
namespace1 + ".svc.clusterset.local."),
},
})
})
})

When("headless service is present in two clusters", func() {
Expand All @@ -685,8 +748,10 @@ func testHeadlessService() {
portNumber1, protcol1, mcsv1a1.Headless))
lh.serviceImports.Put(newServiceImport(namespace1, service1, clusterID2, "", portName1,
portNumber1, protcol1, mcsv1a1.Headless))
lh.endpointSlices.Put(newEndpointSlice(namespace1, service1, clusterID, []string{endpointIP}))
lh.endpointSlices.Put(newEndpointSlice(namespace1, service1, clusterID2, []string{endpointIP2}))
lh.endpointSlices.Put(newEndpointSlice(namespace1, service1, clusterID, portName1, []string{hostName1}, []string{endpointIP},
portNumber1, protcol1))
lh.endpointSlices.Put(newEndpointSlice(namespace1, service1, clusterID2, portName1, []string{hostName2}, []string{endpointIP2},
portNumber1, protcol1))
mockCs.clusterStatusMap[clusterID2] = true
})
When("no cluster is requested", func() {
Expand Down Expand Up @@ -966,7 +1031,7 @@ func setupServiceImportMap() *serviceimport.Map {

func setupEndpointSliceMap() *endpointslice.Map {
esMap := endpointslice.NewMap()
esMap.Put(newEndpointSlice(namespace1, service1, clusterID, []string{endpointIP}))
esMap.Put(newEndpointSlice(namespace1, service1, clusterID, portName1, []string{hostName1}, []string{endpointIP}, portNumber1, protcol1))

return esMap
}
Expand Down Expand Up @@ -1006,7 +1071,18 @@ func newServiceImport(namespace, name, clusterID, serviceIP, portName string,
}
}

func newEndpointSlice(namespace, name, clusterID string, endpointIPs []string) *discovery.EndpointSlice {
func newEndpointSlice(namespace, name, clusterID, portName string, hostName, endpointIPs []string, portNumber int32,
protocol v1.Protocol) *discovery.EndpointSlice {
endpoints := make([]discovery.Endpoint, len(endpointIPs))

for i := range endpointIPs {
endpoint := discovery.Endpoint{
Addresses: []string{endpointIPs[i]},
Hostname: &hostName[i],
}
endpoints[i] = endpoint
}

return &discovery.EndpointSlice{
ObjectMeta: metav1.ObjectMeta{
Name: name,
Expand All @@ -1020,9 +1096,12 @@ func newEndpointSlice(namespace, name, clusterID string, endpointIPs []string) *
},
},
AddressType: discovery.AddressTypeIPv4,
Endpoints: []discovery.Endpoint{
Endpoints: endpoints,
Ports: []discovery.EndpointPort{
{
Addresses: endpointIPs,
Name: &portName,
Protocol: &protocol,
Port: &portNumber,
},
},
}
Expand Down
Loading

0 comments on commit 95c852d

Please sign in to comment.