Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use state file for updating N+ upstreams #2897

Merged
merged 4 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 87 additions & 61 deletions internal/mode/static/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import (
"context"
"errors"
"fmt"
"sync"
"time"
Expand Down Expand Up @@ -182,11 +183,11 @@

h.setLatestConfiguration(&cfg)

err = h.updateUpstreamServers(
ctx,
logger,
cfg,
)
if h.cfg.plus {
err = h.updateUpstreamServers(cfg)
} else {
err = h.updateNginxConf(ctx, cfg)
}
case state.ClusterStateChange:
h.version++
cfg := dataplane.BuildConfiguration(ctx, gr, h.cfg.serviceResolver, h.version)
Expand All @@ -198,10 +199,7 @@

h.setLatestConfiguration(&cfg)

err = h.updateNginxConf(
ctx,
cfg,
)
err = h.updateNginxConf(ctx, cfg)
}

var nginxReloadRes status.NginxReloadResult
Expand Down Expand Up @@ -306,7 +304,10 @@
}

// updateNginxConf updates nginx conf files and reloads nginx.
func (h *eventHandlerImpl) updateNginxConf(ctx context.Context, conf dataplane.Configuration) error {
func (h *eventHandlerImpl) updateNginxConf(
ctx context.Context,
conf dataplane.Configuration,
) error {
files := h.cfg.generator.Generate(conf)
if err := h.cfg.nginxFileMgr.ReplaceFiles(files); err != nil {
return fmt.Errorf("failed to replace NGINX configuration files: %w", err)
Expand All @@ -316,89 +317,114 @@
return fmt.Errorf("failed to reload NGINX: %w", err)
}

// If using NGINX Plus, update upstream servers using the API.
if err := h.updateUpstreamServers(conf); err != nil {
return fmt.Errorf("failed to update upstream servers: %w", err)
}

Check warning on line 323 in internal/mode/static/handler.go

View check run for this annotation

Codecov / codecov/patch

internal/mode/static/handler.go#L322-L323

Added lines #L322 - L323 were not covered by tests

return nil
}

// updateUpstreamServers is called only when endpoints have changed. It updates nginx conf files and then:
// - if using NGINX Plus, determines which servers have changed and uses the N+ API to update them;
// - otherwise if not using NGINX Plus, or an error was returned from the API, reloads nginx.
func (h *eventHandlerImpl) updateUpstreamServers(
ctx context.Context,
logger logr.Logger,
conf dataplane.Configuration,
) error {
isPlus := h.cfg.nginxRuntimeMgr.IsPlus()

files := h.cfg.generator.Generate(conf)
if err := h.cfg.nginxFileMgr.ReplaceFiles(files); err != nil {
return fmt.Errorf("failed to replace NGINX configuration files: %w", err)
// updateUpstreamServers determines which servers have changed and uses the NGINX Plus API to update them.
// Only applicable when using NGINX Plus.
func (h *eventHandlerImpl) updateUpstreamServers(conf dataplane.Configuration) error {
if !h.cfg.plus {
return nil
}

reload := func() error {
if err := h.cfg.nginxRuntimeMgr.Reload(ctx, conf.Version); err != nil {
return fmt.Errorf("failed to reload NGINX: %w", err)
}
prevUpstreams, prevStreamUpstreams, err := h.cfg.nginxRuntimeMgr.GetUpstreams()
if err != nil {
return fmt.Errorf("failed to get upstreams from API: %w", err)
}

return nil
type upstream struct {
name string
servers []ngxclient.UpstreamServer
}
var upstreams []upstream

if isPlus {
type upstream struct {
name string
servers []ngxclient.UpstreamServer
for _, u := range conf.Upstreams {
confUpstream := upstream{
name: u.Name,
servers: ngxConfig.ConvertEndpoints(u.Endpoints),
}
var upstreams []upstream

prevUpstreams, err := h.cfg.nginxRuntimeMgr.GetUpstreams()
if err != nil {
logger.Error(err, "failed to get upstreams from API, reloading configuration instead")
return reload()
if u, ok := prevUpstreams[confUpstream.name]; ok {
if !serversEqual(confUpstream.servers, u.Peers) {
upstreams = append(upstreams, confUpstream)
}
}
}

for _, u := range conf.Upstreams {
confUpstream := upstream{
name: u.Name,
servers: ngxConfig.ConvertEndpoints(u.Endpoints),
}
type streamUpstream struct {
name string
servers []ngxclient.StreamUpstreamServer
}
var streamUpstreams []streamUpstream

if u, ok := prevUpstreams[confUpstream.name]; ok {
if !serversEqual(confUpstream.servers, u.Peers) {
upstreams = append(upstreams, confUpstream)
}
}
for _, u := range conf.StreamUpstreams {
confUpstream := streamUpstream{
name: u.Name,
servers: ngxConfig.ConvertStreamEndpoints(u.Endpoints),
}

var reloadPlus bool
for _, upstream := range upstreams {
if err := h.cfg.nginxRuntimeMgr.UpdateHTTPServers(upstream.name, upstream.servers); err != nil {
logger.Error(
err, "couldn't update upstream via the API, reloading configuration instead",
"upstreamName", upstream.name,
)
reloadPlus = true
if u, ok := prevStreamUpstreams[confUpstream.name]; ok {
if !serversEqual(confUpstream.servers, u.Peers) {
streamUpstreams = append(streamUpstreams, confUpstream)
}
}
}

if !reloadPlus {
return nil
var updateErr error
for _, upstream := range upstreams {
if err := h.cfg.nginxRuntimeMgr.UpdateHTTPServers(upstream.name, upstream.servers); err != nil {
updateErr = errors.Join(updateErr, fmt.Errorf(
"couldn't update upstream %q via the API: %w", upstream.name, err))
}
}

return reload()
for _, upstream := range streamUpstreams {
if err := h.cfg.nginxRuntimeMgr.UpdateStreamServers(upstream.name, upstream.servers); err != nil {
updateErr = errors.Join(updateErr, fmt.Errorf(
"couldn't update stream upstream %q via the API: %w", upstream.name, err))
}
}

return updateErr
}

func serversEqual(newServers []ngxclient.UpstreamServer, oldServers []ngxclient.Peer) bool {
// serversEqual accepts lists of either UpstreamServer/Peer or StreamUpstreamServer/StreamPeer and determines
// if the server names within these lists are equal.
func serversEqual[
upstreamServer ngxclient.UpstreamServer | ngxclient.StreamUpstreamServer,
peer ngxclient.Peer | ngxclient.StreamPeer,
](newServers []upstreamServer, oldServers []peer) bool {
sjberman marked this conversation as resolved.
Show resolved Hide resolved
if len(newServers) != len(oldServers) {
return false
}

getServerVal := func(T any) string {
var server string
switch t := T.(type) {
case ngxclient.UpstreamServer:
server = t.Server
case ngxclient.StreamUpstreamServer:
server = t.Server
case ngxclient.Peer:
server = t.Server
case ngxclient.StreamPeer:
server = t.Server
}
return server
}

diff := make(map[string]struct{}, len(newServers))
for _, s := range newServers {
diff[s.Server] = struct{}{}
diff[getServerVal(s)] = struct{}{}
}

for _, s := range oldServers {
if _, ok := diff[s.Server]; !ok {
if _, ok := diff[getServerVal(s)]; !ok {
return false
}
}
Expand Down
122 changes: 81 additions & 41 deletions internal/mode/static/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -423,20 +423,29 @@ var _ = Describe("eventHandler", func() {
},
},
}
fakeNginxRuntimeMgr.GetUpstreamsReturns(upstreams, nil)

streamUpstreams := ngxclient.StreamUpstreams{
"two": ngxclient.StreamUpstream{
Peers: []ngxclient.StreamPeer{
{Server: "server2"},
},
},
}

fakeNginxRuntimeMgr.GetUpstreamsReturns(upstreams, streamUpstreams, nil)
})

When("running NGINX Plus", func() {
It("should call the NGINX Plus API", func() {
fakeNginxRuntimeMgr.IsPlusReturns(true)
handler.cfg.plus = true

handler.HandleEventBatch(context.Background(), ctlrZap.New(), batch)

dcfg := dataplane.GetDefaultConfiguration(&graph.Graph{}, 1)
Expect(helpers.Diff(handler.GetLatestConfiguration(), &dcfg)).To(BeEmpty())

Expect(fakeGenerator.GenerateCallCount()).To(Equal(1))
Expect(fakeNginxFileMgr.ReplaceFilesCallCount()).To(Equal(1))
Expect(fakeGenerator.GenerateCallCount()).To(Equal(0))
Expect(fakeNginxFileMgr.ReplaceFilesCallCount()).To(Equal(0))
Expect(fakeNginxRuntimeMgr.GetUpstreamsCallCount()).To(Equal(1))
})
})
Expand All @@ -463,19 +472,11 @@ var _ = Describe("eventHandler", func() {
Name: "one",
},
},
}

type callCounts struct {
generate int
update int
reload int
}

assertCallCounts := func(cc callCounts) {
Expect(fakeGenerator.GenerateCallCount()).To(Equal(cc.generate))
Expect(fakeNginxFileMgr.ReplaceFilesCallCount()).To(Equal(cc.generate))
Expect(fakeNginxRuntimeMgr.UpdateHTTPServersCallCount()).To(Equal(cc.update))
Expect(fakeNginxRuntimeMgr.ReloadCallCount()).To(Equal(cc.reload))
StreamUpstreams: []dataplane.Upstream{
{
Name: "two",
},
},
}

BeforeEach(func() {
Expand All @@ -486,47 +487,49 @@ var _ = Describe("eventHandler", func() {
},
},
}
fakeNginxRuntimeMgr.GetUpstreamsReturns(upstreams, nil)

streamUpstreams := ngxclient.StreamUpstreams{
"two": ngxclient.StreamUpstream{
Peers: []ngxclient.StreamPeer{
{Server: "server2"},
},
},
}

fakeNginxRuntimeMgr.GetUpstreamsReturns(upstreams, streamUpstreams, nil)
})

When("running NGINX Plus", func() {
BeforeEach(func() {
fakeNginxRuntimeMgr.IsPlusReturns(true)
handler.cfg.plus = true
})

It("should update servers using the NGINX Plus API", func() {
Expect(handler.updateUpstreamServers(context.Background(), ctlrZap.New(), conf)).To(Succeed())

assertCallCounts(callCounts{generate: 1, update: 1, reload: 0})
Expect(handler.updateUpstreamServers(conf)).To(Succeed())
Expect(fakeNginxRuntimeMgr.UpdateHTTPServersCallCount()).To(Equal(1))
})

It("should reload when GET API returns an error", func() {
fakeNginxRuntimeMgr.GetUpstreamsReturns(nil, errors.New("error"))
Expect(handler.updateUpstreamServers(context.Background(), ctlrZap.New(), conf)).To(Succeed())

assertCallCounts(callCounts{generate: 1, update: 0, reload: 1})
It("should return error when GET API returns an error", func() {
fakeNginxRuntimeMgr.GetUpstreamsReturns(nil, nil, errors.New("error"))
Expect(handler.updateUpstreamServers(conf)).ToNot(Succeed())
})

It("should reload when POST API returns an error", func() {
It("should return error when UpdateHTTPServers API returns an error", func() {
fakeNginxRuntimeMgr.UpdateHTTPServersReturns(errors.New("error"))
Expect(handler.updateUpstreamServers(context.Background(), ctlrZap.New(), conf)).To(Succeed())
Expect(handler.updateUpstreamServers(conf)).ToNot(Succeed())
})

assertCallCounts(callCounts{generate: 1, update: 1, reload: 1})
It("should return error when UpdateStreamServers API returns an error", func() {
fakeNginxRuntimeMgr.UpdateStreamServersReturns(errors.New("error"))
Expect(handler.updateUpstreamServers(conf)).ToNot(Succeed())
})
})

When("not running NGINX Plus", func() {
It("should update servers by reloading", func() {
Expect(handler.updateUpstreamServers(context.Background(), ctlrZap.New(), conf)).To(Succeed())

assertCallCounts(callCounts{generate: 1, update: 0, reload: 1})
})
It("should not do anything", func() {
Expect(handler.updateUpstreamServers(conf)).To(Succeed())

It("should return an error when reloading fails", func() {
fakeNginxRuntimeMgr.ReloadReturns(errors.New("error"))
Expect(handler.updateUpstreamServers(context.Background(), ctlrZap.New(), conf)).ToNot(Succeed())

assertCallCounts(callCounts{generate: 1, update: 0, reload: 1})
Expect(fakeNginxRuntimeMgr.UpdateHTTPServersCallCount()).To(Equal(0))
})
})
})
Expand Down Expand Up @@ -612,7 +615,7 @@ var _ = Describe("eventHandler", func() {
})

var _ = Describe("serversEqual", func() {
DescribeTable("determines if server lists are equal",
DescribeTable("determines if HTTP server lists are equal",
func(newServers []ngxclient.UpstreamServer, oldServers []ngxclient.Peer, equal bool) {
Expect(serversEqual(newServers, oldServers)).To(Equal(equal))
},
Expand Down Expand Up @@ -649,6 +652,43 @@ var _ = Describe("serversEqual", func() {
true,
),
)
DescribeTable("determines if stream server lists are equal",
func(newServers []ngxclient.StreamUpstreamServer, oldServers []ngxclient.StreamPeer, equal bool) {
Expect(serversEqual(newServers, oldServers)).To(Equal(equal))
},
Entry("different length",
[]ngxclient.StreamUpstreamServer{
{Server: "server1"},
},
[]ngxclient.StreamPeer{
{Server: "server1"},
{Server: "server2"},
},
false,
),
Entry("differing elements",
[]ngxclient.StreamUpstreamServer{
{Server: "server1"},
{Server: "server2"},
},
[]ngxclient.StreamPeer{
{Server: "server1"},
{Server: "server3"},
},
false,
),
Entry("same elements",
[]ngxclient.StreamUpstreamServer{
{Server: "server1"},
{Server: "server2"},
},
[]ngxclient.StreamPeer{
{Server: "server1"},
{Server: "server2"},
},
true,
),
)
})

var _ = Describe("getGatewayAddresses", func() {
Expand Down
Loading
Loading