diff --git a/internal/mode/static/handler.go b/internal/mode/static/handler.go index e09732e7f0..2a835b8584 100644 --- a/internal/mode/static/handler.go +++ b/internal/mode/static/handler.go @@ -2,6 +2,7 @@ package static import ( "context" + "errors" "fmt" "sync" "time" @@ -182,11 +183,11 @@ func (h *eventHandlerImpl) HandleEventBatch(ctx context.Context, logger logr.Log h.setLatestConfiguration(&cfg) - err = h.updateUpstreamServers( - ctx, - logger, - cfg, - ) + if h.cfg.plus { + err = h.updateUpstreamServers(cfg) + } else { + err = h.updateNginxConf(ctx, cfg) + } case state.ClusterStateChange: h.version++ cfg := dataplane.BuildConfiguration(ctx, gr, h.cfg.serviceResolver, h.version) @@ -198,10 +199,7 @@ func (h *eventHandlerImpl) HandleEventBatch(ctx context.Context, logger logr.Log h.setLatestConfiguration(&cfg) - err = h.updateNginxConf( - ctx, - cfg, - ) + err = h.updateNginxConf(ctx, cfg) } var nginxReloadRes status.NginxReloadResult @@ -306,7 +304,10 @@ func (h *eventHandlerImpl) parseAndCaptureEvent(ctx context.Context, logger logr } // updateNginxConf updates nginx conf files and reloads nginx. -func (h *eventHandlerImpl) updateNginxConf(ctx context.Context, conf dataplane.Configuration) error { +func (h *eventHandlerImpl) updateNginxConf( + ctx context.Context, + conf dataplane.Configuration, +) error { files := h.cfg.generator.Generate(conf) if err := h.cfg.nginxFileMgr.ReplaceFiles(files); err != nil { return fmt.Errorf("failed to replace NGINX configuration files: %w", err) @@ -316,89 +317,114 @@ func (h *eventHandlerImpl) updateNginxConf(ctx context.Context, conf dataplane.C return fmt.Errorf("failed to reload NGINX: %w", err) } + // If using NGINX Plus, update upstream servers using the API. + if err := h.updateUpstreamServers(conf); err != nil { + return fmt.Errorf("failed to update upstream servers: %w", err) + } + return nil } -// updateUpstreamServers is called only when endpoints have changed. It updates nginx conf files and then: -// - if using NGINX Plus, determines which servers have changed and uses the N+ API to update them; -// - otherwise if not using NGINX Plus, or an error was returned from the API, reloads nginx. -func (h *eventHandlerImpl) updateUpstreamServers( - ctx context.Context, - logger logr.Logger, - conf dataplane.Configuration, -) error { - isPlus := h.cfg.nginxRuntimeMgr.IsPlus() - - files := h.cfg.generator.Generate(conf) - if err := h.cfg.nginxFileMgr.ReplaceFiles(files); err != nil { - return fmt.Errorf("failed to replace NGINX configuration files: %w", err) +// updateUpstreamServers determines which servers have changed and uses the NGINX Plus API to update them. +// Only applicable when using NGINX Plus. +func (h *eventHandlerImpl) updateUpstreamServers(conf dataplane.Configuration) error { + if !h.cfg.plus { + return nil } - reload := func() error { - if err := h.cfg.nginxRuntimeMgr.Reload(ctx, conf.Version); err != nil { - return fmt.Errorf("failed to reload NGINX: %w", err) - } + prevUpstreams, prevStreamUpstreams, err := h.cfg.nginxRuntimeMgr.GetUpstreams() + if err != nil { + return fmt.Errorf("failed to get upstreams from API: %w", err) + } - return nil + type upstream struct { + name string + servers []ngxclient.UpstreamServer } + var upstreams []upstream - if isPlus { - type upstream struct { - name string - servers []ngxclient.UpstreamServer + for _, u := range conf.Upstreams { + confUpstream := upstream{ + name: u.Name, + servers: ngxConfig.ConvertEndpoints(u.Endpoints), } - var upstreams []upstream - prevUpstreams, err := h.cfg.nginxRuntimeMgr.GetUpstreams() - if err != nil { - logger.Error(err, "failed to get upstreams from API, reloading configuration instead") - return reload() + if u, ok := prevUpstreams[confUpstream.name]; ok { + if !serversEqual(confUpstream.servers, u.Peers) { + upstreams = append(upstreams, confUpstream) + } } + } - for _, u := range conf.Upstreams { - confUpstream := upstream{ - name: u.Name, - servers: ngxConfig.ConvertEndpoints(u.Endpoints), - } + type streamUpstream struct { + name string + servers []ngxclient.StreamUpstreamServer + } + var streamUpstreams []streamUpstream - if u, ok := prevUpstreams[confUpstream.name]; ok { - if !serversEqual(confUpstream.servers, u.Peers) { - upstreams = append(upstreams, confUpstream) - } - } + for _, u := range conf.StreamUpstreams { + confUpstream := streamUpstream{ + name: u.Name, + servers: ngxConfig.ConvertStreamEndpoints(u.Endpoints), } - var reloadPlus bool - for _, upstream := range upstreams { - if err := h.cfg.nginxRuntimeMgr.UpdateHTTPServers(upstream.name, upstream.servers); err != nil { - logger.Error( - err, "couldn't update upstream via the API, reloading configuration instead", - "upstreamName", upstream.name, - ) - reloadPlus = true + if u, ok := prevStreamUpstreams[confUpstream.name]; ok { + if !serversEqual(confUpstream.servers, u.Peers) { + streamUpstreams = append(streamUpstreams, confUpstream) } } + } - if !reloadPlus { - return nil + var updateErr error + for _, upstream := range upstreams { + if err := h.cfg.nginxRuntimeMgr.UpdateHTTPServers(upstream.name, upstream.servers); err != nil { + updateErr = errors.Join(updateErr, fmt.Errorf( + "couldn't update upstream %q via the API: %w", upstream.name, err)) } } - return reload() + for _, upstream := range streamUpstreams { + if err := h.cfg.nginxRuntimeMgr.UpdateStreamServers(upstream.name, upstream.servers); err != nil { + updateErr = errors.Join(updateErr, fmt.Errorf( + "couldn't update stream upstream %q via the API: %w", upstream.name, err)) + } + } + + return updateErr } -func serversEqual(newServers []ngxclient.UpstreamServer, oldServers []ngxclient.Peer) bool { +// serversEqual accepts lists of either UpstreamServer/Peer or StreamUpstreamServer/StreamPeer and determines +// if the server names within these lists are equal. +func serversEqual[ + upstreamServer ngxclient.UpstreamServer | ngxclient.StreamUpstreamServer, + peer ngxclient.Peer | ngxclient.StreamPeer, +](newServers []upstreamServer, oldServers []peer) bool { if len(newServers) != len(oldServers) { return false } + getServerVal := func(T any) string { + var server string + switch t := T.(type) { + case ngxclient.UpstreamServer: + server = t.Server + case ngxclient.StreamUpstreamServer: + server = t.Server + case ngxclient.Peer: + server = t.Server + case ngxclient.StreamPeer: + server = t.Server + } + return server + } + diff := make(map[string]struct{}, len(newServers)) for _, s := range newServers { - diff[s.Server] = struct{}{} + diff[getServerVal(s)] = struct{}{} } for _, s := range oldServers { - if _, ok := diff[s.Server]; !ok { + if _, ok := diff[getServerVal(s)]; !ok { return false } } diff --git a/internal/mode/static/handler_test.go b/internal/mode/static/handler_test.go index 67bf0e8e0e..c24f5d27d2 100644 --- a/internal/mode/static/handler_test.go +++ b/internal/mode/static/handler_test.go @@ -423,20 +423,29 @@ var _ = Describe("eventHandler", func() { }, }, } - fakeNginxRuntimeMgr.GetUpstreamsReturns(upstreams, nil) + + streamUpstreams := ngxclient.StreamUpstreams{ + "two": ngxclient.StreamUpstream{ + Peers: []ngxclient.StreamPeer{ + {Server: "server2"}, + }, + }, + } + + fakeNginxRuntimeMgr.GetUpstreamsReturns(upstreams, streamUpstreams, nil) }) When("running NGINX Plus", func() { It("should call the NGINX Plus API", func() { - fakeNginxRuntimeMgr.IsPlusReturns(true) + handler.cfg.plus = true handler.HandleEventBatch(context.Background(), ctlrZap.New(), batch) dcfg := dataplane.GetDefaultConfiguration(&graph.Graph{}, 1) Expect(helpers.Diff(handler.GetLatestConfiguration(), &dcfg)).To(BeEmpty()) - Expect(fakeGenerator.GenerateCallCount()).To(Equal(1)) - Expect(fakeNginxFileMgr.ReplaceFilesCallCount()).To(Equal(1)) + Expect(fakeGenerator.GenerateCallCount()).To(Equal(0)) + Expect(fakeNginxFileMgr.ReplaceFilesCallCount()).To(Equal(0)) Expect(fakeNginxRuntimeMgr.GetUpstreamsCallCount()).To(Equal(1)) }) }) @@ -463,19 +472,11 @@ var _ = Describe("eventHandler", func() { Name: "one", }, }, - } - - type callCounts struct { - generate int - update int - reload int - } - - assertCallCounts := func(cc callCounts) { - Expect(fakeGenerator.GenerateCallCount()).To(Equal(cc.generate)) - Expect(fakeNginxFileMgr.ReplaceFilesCallCount()).To(Equal(cc.generate)) - Expect(fakeNginxRuntimeMgr.UpdateHTTPServersCallCount()).To(Equal(cc.update)) - Expect(fakeNginxRuntimeMgr.ReloadCallCount()).To(Equal(cc.reload)) + StreamUpstreams: []dataplane.Upstream{ + { + Name: "two", + }, + }, } BeforeEach(func() { @@ -486,47 +487,49 @@ var _ = Describe("eventHandler", func() { }, }, } - fakeNginxRuntimeMgr.GetUpstreamsReturns(upstreams, nil) + + streamUpstreams := ngxclient.StreamUpstreams{ + "two": ngxclient.StreamUpstream{ + Peers: []ngxclient.StreamPeer{ + {Server: "server2"}, + }, + }, + } + + fakeNginxRuntimeMgr.GetUpstreamsReturns(upstreams, streamUpstreams, nil) }) When("running NGINX Plus", func() { BeforeEach(func() { - fakeNginxRuntimeMgr.IsPlusReturns(true) + handler.cfg.plus = true }) It("should update servers using the NGINX Plus API", func() { - Expect(handler.updateUpstreamServers(context.Background(), ctlrZap.New(), conf)).To(Succeed()) - - assertCallCounts(callCounts{generate: 1, update: 1, reload: 0}) + Expect(handler.updateUpstreamServers(conf)).To(Succeed()) + Expect(fakeNginxRuntimeMgr.UpdateHTTPServersCallCount()).To(Equal(1)) }) - It("should reload when GET API returns an error", func() { - fakeNginxRuntimeMgr.GetUpstreamsReturns(nil, errors.New("error")) - Expect(handler.updateUpstreamServers(context.Background(), ctlrZap.New(), conf)).To(Succeed()) - - assertCallCounts(callCounts{generate: 1, update: 0, reload: 1}) + It("should return error when GET API returns an error", func() { + fakeNginxRuntimeMgr.GetUpstreamsReturns(nil, nil, errors.New("error")) + Expect(handler.updateUpstreamServers(conf)).ToNot(Succeed()) }) - It("should reload when POST API returns an error", func() { + It("should return error when UpdateHTTPServers API returns an error", func() { fakeNginxRuntimeMgr.UpdateHTTPServersReturns(errors.New("error")) - Expect(handler.updateUpstreamServers(context.Background(), ctlrZap.New(), conf)).To(Succeed()) + Expect(handler.updateUpstreamServers(conf)).ToNot(Succeed()) + }) - assertCallCounts(callCounts{generate: 1, update: 1, reload: 1}) + It("should return error when UpdateStreamServers API returns an error", func() { + fakeNginxRuntimeMgr.UpdateStreamServersReturns(errors.New("error")) + Expect(handler.updateUpstreamServers(conf)).ToNot(Succeed()) }) }) When("not running NGINX Plus", func() { - It("should update servers by reloading", func() { - Expect(handler.updateUpstreamServers(context.Background(), ctlrZap.New(), conf)).To(Succeed()) - - assertCallCounts(callCounts{generate: 1, update: 0, reload: 1}) - }) + It("should not do anything", func() { + Expect(handler.updateUpstreamServers(conf)).To(Succeed()) - It("should return an error when reloading fails", func() { - fakeNginxRuntimeMgr.ReloadReturns(errors.New("error")) - Expect(handler.updateUpstreamServers(context.Background(), ctlrZap.New(), conf)).ToNot(Succeed()) - - assertCallCounts(callCounts{generate: 1, update: 0, reload: 1}) + Expect(fakeNginxRuntimeMgr.UpdateHTTPServersCallCount()).To(Equal(0)) }) }) }) @@ -612,7 +615,7 @@ var _ = Describe("eventHandler", func() { }) var _ = Describe("serversEqual", func() { - DescribeTable("determines if server lists are equal", + DescribeTable("determines if HTTP server lists are equal", func(newServers []ngxclient.UpstreamServer, oldServers []ngxclient.Peer, equal bool) { Expect(serversEqual(newServers, oldServers)).To(Equal(equal)) }, @@ -649,6 +652,43 @@ var _ = Describe("serversEqual", func() { true, ), ) + DescribeTable("determines if stream server lists are equal", + func(newServers []ngxclient.StreamUpstreamServer, oldServers []ngxclient.StreamPeer, equal bool) { + Expect(serversEqual(newServers, oldServers)).To(Equal(equal)) + }, + Entry("different length", + []ngxclient.StreamUpstreamServer{ + {Server: "server1"}, + }, + []ngxclient.StreamPeer{ + {Server: "server1"}, + {Server: "server2"}, + }, + false, + ), + Entry("differing elements", + []ngxclient.StreamUpstreamServer{ + {Server: "server1"}, + {Server: "server2"}, + }, + []ngxclient.StreamPeer{ + {Server: "server1"}, + {Server: "server3"}, + }, + false, + ), + Entry("same elements", + []ngxclient.StreamUpstreamServer{ + {Server: "server1"}, + {Server: "server2"}, + }, + []ngxclient.StreamPeer{ + {Server: "server1"}, + {Server: "server2"}, + }, + true, + ), + ) }) var _ = Describe("getGatewayAddresses", func() { diff --git a/internal/mode/static/manager.go b/internal/mode/static/manager.go index bc24210318..bc94e61346 100644 --- a/internal/mode/static/manager.go +++ b/internal/mode/static/manager.go @@ -172,15 +172,17 @@ func StartManager(cfg config.Config) error { ) var ngxPlusClient ngxruntime.NginxPlusClient + if cfg.Plus { + ngxPlusClient, err = ngxruntime.CreatePlusClient() + if err != nil { + return fmt.Errorf("error creating NGINX plus client: %w", err) + } + } if cfg.MetricsConfig.Enabled { constLabels := map[string]string{"class": cfg.GatewayClassName} var ngxCollector prometheus.Collector if cfg.Plus { - ngxPlusClient, err = ngxruntime.CreatePlusClient() - if err != nil { - return fmt.Errorf("error creating NGINX plus client: %w", err) - } ngxCollector, err = collectors.NewNginxPlusMetricsCollector(ngxPlusClient, constLabels, promLogger) } else { ngxCollector = collectors.NewNginxMetricsCollector(constLabels, promLogger) diff --git a/internal/mode/static/nginx/config/convert.go b/internal/mode/static/nginx/config/convert.go index ff20bf888d..3038149a0e 100644 --- a/internal/mode/static/nginx/config/convert.go +++ b/internal/mode/static/nginx/config/convert.go @@ -13,17 +13,26 @@ func ConvertEndpoints(eps []resolver.Endpoint) []ngxclient.UpstreamServer { servers := make([]ngxclient.UpstreamServer, 0, len(eps)) for _, ep := range eps { - var port string - if ep.Port != 0 { - port = fmt.Sprintf(":%d", ep.Port) - } + port, format := getPortAndIPFormat(ep) - format := "%s%s" - if ep.IPv6 { - format = "[%s]%s" + server := ngxclient.UpstreamServer{ + Server: fmt.Sprintf(format, ep.Address, port), } - server := ngxclient.UpstreamServer{ + servers = append(servers, server) + } + + return servers +} + +// ConvertStreamEndpoints converts a list of Endpoints into a list of NGINX Plus SDK StreamUpstreamServers. +func ConvertStreamEndpoints(eps []resolver.Endpoint) []ngxclient.StreamUpstreamServer { + servers := make([]ngxclient.StreamUpstreamServer, 0, len(eps)) + + for _, ep := range eps { + port, format := getPortAndIPFormat(ep) + + server := ngxclient.StreamUpstreamServer{ Server: fmt.Sprintf(format, ep.Address, port), } @@ -32,3 +41,18 @@ func ConvertEndpoints(eps []resolver.Endpoint) []ngxclient.UpstreamServer { return servers } + +func getPortAndIPFormat(ep resolver.Endpoint) (string, string) { + var port string + + if ep.Port != 0 { + port = fmt.Sprintf(":%d", ep.Port) + } + + format := "%s%s" + if ep.IPv6 { + format = "[%s]%s" + } + + return port, format +} diff --git a/internal/mode/static/nginx/config/convert_test.go b/internal/mode/static/nginx/config/convert_test.go index 6be41ccda6..68520dfd78 100644 --- a/internal/mode/static/nginx/config/convert_test.go +++ b/internal/mode/static/nginx/config/convert_test.go @@ -42,3 +42,37 @@ func TestConvertEndpoints(t *testing.T) { g := NewWithT(t) g.Expect(ConvertEndpoints(endpoints)).To(Equal(expUpstreams)) } + +func TestConvertStreamEndpoints(t *testing.T) { + t.Parallel() + endpoints := []resolver.Endpoint{ + { + Address: "1.2.3.4", + Port: 80, + }, + { + Address: "5.6.7.8", + Port: 0, + }, + { + Address: "2001:db8::1", + Port: 443, + IPv6: true, + }, + } + + expUpstreams := []ngxclient.StreamUpstreamServer{ + { + Server: "1.2.3.4:80", + }, + { + Server: "5.6.7.8", + }, + { + Server: "[2001:db8::1]:443", + }, + } + + g := NewWithT(t) + g.Expect(ConvertStreamEndpoints(endpoints)).To(Equal(expUpstreams)) +} diff --git a/internal/mode/static/nginx/config/http/config.go b/internal/mode/static/nginx/config/http/config.go index 24aecaa3e4..6d063dc8a7 100644 --- a/internal/mode/static/nginx/config/http/config.go +++ b/internal/mode/static/nginx/config/http/config.go @@ -82,9 +82,10 @@ const ( // Upstream holds all configuration for an HTTP upstream. type Upstream struct { - Name string - ZoneSize string // format: 512k, 1m - Servers []UpstreamServer + Name string + ZoneSize string // format: 512k, 1m + StateFile string + Servers []UpstreamServer } // UpstreamServer holds all configuration for an HTTP upstream server. diff --git a/internal/mode/static/nginx/config/stream/config.go b/internal/mode/static/nginx/config/stream/config.go index ddc215eea7..1202c1ec85 100644 --- a/internal/mode/static/nginx/config/stream/config.go +++ b/internal/mode/static/nginx/config/stream/config.go @@ -15,9 +15,10 @@ type Server struct { // Upstream holds all configuration for a stream upstream. type Upstream struct { - Name string - ZoneSize string // format: 512k, 1m - Servers []UpstreamServer + Name string + ZoneSize string // format: 512k, 1m + StateFile string + Servers []UpstreamServer } // UpstreamServer holds all configuration for a stream upstream server. diff --git a/internal/mode/static/nginx/config/upstreams.go b/internal/mode/static/nginx/config/upstreams.go index 88c66c47fd..51af6f4f4b 100644 --- a/internal/mode/static/nginx/config/upstreams.go +++ b/internal/mode/static/nginx/config/upstreams.go @@ -27,6 +27,8 @@ const ( ossZoneSizeStream = "512k" // plusZoneSize is the upstream zone size for nginx plus. plusZoneSizeStream = "1m" + // stateDir is the directory for storing state files. + stateDir = "/var/lib/nginx/state" ) func (g GeneratorImpl) executeUpstreams(conf dataplane.Configuration) []executeResult { @@ -64,9 +66,11 @@ func (g GeneratorImpl) createStreamUpstreams(upstreams []dataplane.Upstream) []s } func (g GeneratorImpl) createStreamUpstream(up dataplane.Upstream) stream.Upstream { + var stateFile string zoneSize := ossZoneSizeStream if g.plus { zoneSize = plusZoneSizeStream + stateFile = fmt.Sprintf("%s/%s.conf", stateDir, up.Name) } upstreamServers := make([]stream.UpstreamServer, len(up.Endpoints)) @@ -81,9 +85,10 @@ func (g GeneratorImpl) createStreamUpstream(up dataplane.Upstream) stream.Upstre } return stream.Upstream{ - Name: up.Name, - ZoneSize: zoneSize, - Servers: upstreamServers, + Name: up.Name, + ZoneSize: zoneSize, + StateFile: stateFile, + Servers: upstreamServers, } } @@ -101,15 +106,18 @@ func (g GeneratorImpl) createUpstreams(upstreams []dataplane.Upstream) []http.Up } func (g GeneratorImpl) createUpstream(up dataplane.Upstream) http.Upstream { + var stateFile string zoneSize := ossZoneSize if g.plus { zoneSize = plusZoneSize + stateFile = fmt.Sprintf("%s/%s.conf", stateDir, up.Name) } if len(up.Endpoints) == 0 { return http.Upstream{ - Name: up.Name, - ZoneSize: zoneSize, + Name: up.Name, + ZoneSize: zoneSize, + StateFile: stateFile, Servers: []http.UpstreamServer{ { Address: nginx503Server, @@ -130,9 +138,10 @@ func (g GeneratorImpl) createUpstream(up dataplane.Upstream) http.Upstream { } return http.Upstream{ - Name: up.Name, - ZoneSize: zoneSize, - Servers: upstreamServers, + Name: up.Name, + ZoneSize: zoneSize, + StateFile: stateFile, + Servers: upstreamServers, } } diff --git a/internal/mode/static/nginx/config/upstreams_template.go b/internal/mode/static/nginx/config/upstreams_template.go index a04915bec8..40d5740ad0 100644 --- a/internal/mode/static/nginx/config/upstreams_template.go +++ b/internal/mode/static/nginx/config/upstreams_template.go @@ -12,8 +12,13 @@ upstream {{ $u.Name }} { {{ if $u.ZoneSize -}} zone {{ $u.Name }} {{ $u.ZoneSize }}; {{ end -}} - {{ range $server := $u.Servers }} + + {{- if $u.StateFile }} + state {{ $u.StateFile }}; + {{- else }} + {{ range $server := $u.Servers }} server {{ $server.Address }}; + {{- end }} {{- end }} } {{ end -}} diff --git a/internal/mode/static/nginx/config/upstreams_test.go b/internal/mode/static/nginx/config/upstreams_test.go index 5b3a8268a3..f2e5b1071b 100644 --- a/internal/mode/static/nginx/config/upstreams_test.go +++ b/internal/mode/static/nginx/config/upstreams_test.go @@ -289,29 +289,60 @@ func TestCreateUpstreamPlus(t *testing.T) { t.Parallel() gen := GeneratorImpl{plus: true} - stateUpstream := dataplane.Upstream{ - Name: "multiple-endpoints", - Endpoints: []resolver.Endpoint{ - { - Address: "10.0.0.1", - Port: 80, + tests := []struct { + msg string + stateUpstream dataplane.Upstream + expectedUpstream http.Upstream + }{ + { + msg: "with endpoints", + stateUpstream: dataplane.Upstream{ + Name: "endpoints", + Endpoints: []resolver.Endpoint{ + { + Address: "10.0.0.1", + Port: 80, + }, + }, + }, + expectedUpstream: http.Upstream{ + Name: "endpoints", + ZoneSize: plusZoneSize, + StateFile: stateDir + "/endpoints.conf", + Servers: []http.UpstreamServer{ + { + Address: "10.0.0.1:80", + }, + }, }, }, - } - expectedUpstream := http.Upstream{ - Name: "multiple-endpoints", - ZoneSize: plusZoneSize, - Servers: []http.UpstreamServer{ - { - Address: "10.0.0.1:80", + { + msg: "no endpoints", + stateUpstream: dataplane.Upstream{ + Name: "no-endpoints", + Endpoints: []resolver.Endpoint{}, + }, + expectedUpstream: http.Upstream{ + Name: "no-endpoints", + ZoneSize: plusZoneSize, + StateFile: stateDir + "/no-endpoints.conf", + Servers: []http.UpstreamServer{ + { + Address: nginx503Server, + }, + }, }, }, } - result := gen.createUpstream(stateUpstream) - - g := NewWithT(t) - g.Expect(result).To(Equal(expectedUpstream)) + for _, test := range tests { + t.Run(test.msg, func(t *testing.T) { + t.Parallel() + g := NewWithT(t) + result := gen.createUpstream(test.stateUpstream) + g.Expect(result).To(Equal(test.expectedUpstream)) + }) + } } func TestExecuteStreamUpstreams(t *testing.T) { @@ -491,8 +522,9 @@ func TestCreateStreamUpstreamPlus(t *testing.T) { }, } expectedUpstream := stream.Upstream{ - Name: "multiple-endpoints", - ZoneSize: plusZoneSize, + Name: "multiple-endpoints", + ZoneSize: plusZoneSize, + StateFile: stateDir + "/multiple-endpoints.conf", Servers: []stream.UpstreamServer{ { Address: "10.0.0.1:80", diff --git a/internal/mode/static/nginx/runtime/manager.go b/internal/mode/static/nginx/runtime/manager.go index 45d24dbb75..afa641645f 100644 --- a/internal/mode/static/nginx/runtime/manager.go +++ b/internal/mode/static/nginx/runtime/manager.go @@ -47,6 +47,16 @@ type NginxPlusClient interface { err error, ) GetUpstreams() (*ngxclient.Upstreams, error) + UpdateStreamServers( + upstream string, + servers []ngxclient.StreamUpstreamServer, + ) ( + added []ngxclient.StreamUpstreamServer, + deleted []ngxclient.StreamUpstreamServer, + updated []ngxclient.StreamUpstreamServer, + err error, + ) + GetStreamUpstreams() (*ngxclient.StreamUpstreams, error) } //counterfeiter:generate . Manager @@ -57,12 +67,15 @@ type Manager interface { Reload(ctx context.Context, configVersion int) error // IsPlus returns whether or not we are running NGINX plus. IsPlus() bool - // UpdateHTTPServers uses the NGINX Plus API to update HTTP servers. + // GetUpstreams uses the NGINX Plus API to get the upstreams. + // Only usable if running NGINX Plus. + GetUpstreams() (ngxclient.Upstreams, ngxclient.StreamUpstreams, error) + // UpdateHTTPServers uses the NGINX Plus API to update HTTP upstream servers. // Only usable if running NGINX Plus. UpdateHTTPServers(string, []ngxclient.UpstreamServer) error - // GetUpstreams uses the NGINX Plus API to get the upstreams. + // UpdateStreamServers uses the NGINX Plus API to update stream upstream servers. // Only usable if running NGINX Plus. - GetUpstreams() (ngxclient.Upstreams, error) + UpdateStreamServers(string, []ngxclient.StreamUpstreamServer) error } // MetricsCollector is an interface for the metrics of the NGINX runtime manager. @@ -143,6 +156,34 @@ func (m *ManagerImpl) Reload(ctx context.Context, configVersion int) error { return nil } +// GetUpstreams uses the NGINX Plus API to get the upstreams. +// Only usable if running NGINX Plus. +func (m *ManagerImpl) GetUpstreams() (ngxclient.Upstreams, ngxclient.StreamUpstreams, error) { + if !m.IsPlus() { + panic("cannot get upstream servers: NGINX Plus not enabled") + } + + upstreams, err := m.ngxPlusClient.GetUpstreams() + if err != nil { + return nil, nil, err + } + + if upstreams == nil { + return nil, nil, errors.New("GET upstreams returned nil value") + } + + streamUpstreams, err := m.ngxPlusClient.GetStreamUpstreams() + if err != nil { + return nil, nil, err + } + + if streamUpstreams == nil { + return nil, nil, errors.New("GET stream upstreams returned nil value") + } + + return *upstreams, *streamUpstreams, nil +} + // UpdateHTTPServers uses the NGINX Plus API to update HTTP upstream servers. // Only usable if running NGINX Plus. func (m *ManagerImpl) UpdateHTTPServers(upstream string, servers []ngxclient.UpstreamServer) error { @@ -158,23 +199,19 @@ func (m *ManagerImpl) UpdateHTTPServers(upstream string, servers []ngxclient.Ups return err } -// GetUpstreams uses the NGINX Plus API to get the upstreams. +// UpdateStreamServers uses the NGINX Plus API to update stream upstream servers. // Only usable if running NGINX Plus. -func (m *ManagerImpl) GetUpstreams() (ngxclient.Upstreams, error) { +func (m *ManagerImpl) UpdateStreamServers(upstream string, servers []ngxclient.StreamUpstreamServer) error { if !m.IsPlus() { - panic("cannot get HTTP upstream servers: NGINX Plus not enabled") - } - - upstreams, err := m.ngxPlusClient.GetUpstreams() - if err != nil { - return nil, err + panic("cannot update stream upstream servers: NGINX Plus not enabled") } - if upstreams == nil { - return nil, errors.New("GET upstreams returned nil value") - } + added, deleted, updated, err := m.ngxPlusClient.UpdateStreamServers(upstream, servers) + m.logger.V(1).Info("Added stream upstream servers", "count", len(added)) + m.logger.V(1).Info("Deleted stream upstream servers", "count", len(deleted)) + m.logger.V(1).Info("Updated stream upstream servers", "count", len(updated)) - return *upstreams, nil + return err } //counterfeiter:generate . ProcessHandler diff --git a/internal/mode/static/nginx/runtime/manager_test.go b/internal/mode/static/nginx/runtime/manager_test.go index 15eb498a7f..036731e1ea 100644 --- a/internal/mode/static/nginx/runtime/manager_test.go +++ b/internal/mode/static/nginx/runtime/manager_test.go @@ -27,11 +27,12 @@ var _ = Describe("NGINX Runtime Manager", func() { }) var ( - err error - manager runtime.Manager - upstreamServers []ngxclient.UpstreamServer - ngxPlusClient *runtimefakes.FakeNginxPlusClient - process *runtimefakes.FakeProcessHandler + err error + manager runtime.Manager + upstreamServers []ngxclient.UpstreamServer + streamUpstreamServers []ngxclient.StreamUpstreamServer + ngxPlusClient *runtimefakes.FakeNginxPlusClient + process *runtimefakes.FakeProcessHandler metrics *runtimefakes.FakeMetricsCollector verifyClient *runtimefakes.FakeVerifyClient @@ -41,6 +42,9 @@ var _ = Describe("NGINX Runtime Manager", func() { upstreamServers = []ngxclient.UpstreamServer{ {}, } + streamUpstreamServers = []ngxclient.StreamUpstreamServer{ + {}, + } }) Context("Reload", func() { @@ -150,11 +154,16 @@ var _ = Describe("NGINX Runtime Manager", func() { Expect(manager.UpdateHTTPServers("test", upstreamServers)).To(Succeed()) }) + It("successfully updates stream server upstream", func() { + Expect(manager.UpdateStreamServers("test", streamUpstreamServers)).To(Succeed()) + }) + It("returns no upstreams from NGINX Plus API when upstreams are nil", func() { - upstreams, err := manager.GetUpstreams() + upstreams, streamUpstreams, err := manager.GetUpstreams() Expect(err).To(HaveOccurred()) Expect(upstreams).To(BeEmpty()) + Expect(streamUpstreams).To(BeEmpty()) }) It("successfully returns server upstreams", func() { @@ -177,22 +186,77 @@ var _ = Describe("NGINX Runtime Manager", func() { }, } + expStreamUpstreams := ngxclient.StreamUpstreams{ + "upstream1": { + Zone: "zone1", + Peers: []ngxclient.StreamPeer{ + {ID: 1, Name: "peer1-name"}, + }, + Zombies: 2, + }, + "upstream2": { + Zone: "zone2", + Peers: []ngxclient.StreamPeer{ + {ID: 2, Name: "peer2-name"}, + }, + Zombies: 1, + }, + } + ngxPlusClient.GetUpstreamsReturns(&expUpstreams, nil) + ngxPlusClient.GetStreamUpstreamsReturns(&expStreamUpstreams, nil) - upstreams, err := manager.GetUpstreams() + upstreams, streamUpstreams, err := manager.GetUpstreams() Expect(err).NotTo(HaveOccurred()) Expect(expUpstreams).To(Equal(upstreams)) + Expect(expStreamUpstreams).To(Equal(streamUpstreams)) }) It("returns an error when GetUpstreams fails", func() { ngxPlusClient.GetUpstreamsReturns(nil, errors.New("failed to get upstreams")) - upstreams, err := manager.GetUpstreams() + upstreams, streamUpstreams, err := manager.GetUpstreams() + + Expect(err).To(HaveOccurred()) + Expect(err).To(MatchError("failed to get upstreams")) + Expect(upstreams).To(BeNil()) + Expect(streamUpstreams).To(BeNil()) + }) + + It("returns an error when GetUpstreams returns nil", func() { + ngxPlusClient.GetUpstreamsReturns(nil, nil) + + upstreams, streamUpstreams, err := manager.GetUpstreams() + + Expect(err).To(HaveOccurred()) + Expect(err).To(MatchError("GET upstreams returned nil value")) + Expect(upstreams).To(BeNil()) + Expect(streamUpstreams).To(BeNil()) + }) + + It("returns an error when GetStreamUpstreams fails", func() { + ngxPlusClient.GetUpstreamsReturns(&ngxclient.Upstreams{}, nil) + ngxPlusClient.GetStreamUpstreamsReturns(nil, errors.New("failed to get upstreams")) + + upstreams, streamUpstreams, err := manager.GetUpstreams() Expect(err).To(HaveOccurred()) Expect(err).To(MatchError("failed to get upstreams")) Expect(upstreams).To(BeNil()) + Expect(streamUpstreams).To(BeNil()) + }) + + It("returns an error when GetStreamUpstreams returns nil", func() { + ngxPlusClient.GetUpstreamsReturns(&ngxclient.Upstreams{}, nil) + ngxPlusClient.GetStreamUpstreamsReturns(nil, nil) + + upstreams, streamUpstreams, err := manager.GetUpstreams() + + Expect(err).To(HaveOccurred()) + Expect(err).To(MatchError("GET stream upstreams returned nil value")) + Expect(upstreams).To(BeNil()) + Expect(streamUpstreams).To(BeNil()) }) }) @@ -202,6 +266,15 @@ var _ = Describe("NGINX Runtime Manager", func() { manager = runtime.NewManagerImpl(ngxPlusClient, nil, zap.New(), nil, nil) }) + It("should panic when fetching upstream servers", func() { + upstreams := func() { + _, _, err = manager.GetUpstreams() + } + + Expect(upstreams).To(Panic()) + Expect(err).ToNot(HaveOccurred()) + }) + It("should panic when updating HTTP upstream servers", func() { updateServers := func() { err = manager.UpdateHTTPServers("test", upstreamServers) @@ -211,12 +284,12 @@ var _ = Describe("NGINX Runtime Manager", func() { Expect(err).ToNot(HaveOccurred()) }) - It("should panic when fetching HTTP upstream servers", func() { - upstreams := func() { - _, err = manager.GetUpstreams() + It("should panic when updating stream upstream servers", func() { + updateServers := func() { + err = manager.UpdateStreamServers("test", streamUpstreamServers) } - Expect(upstreams).To(Panic()) + Expect(updateServers).To(Panic()) Expect(err).ToNot(HaveOccurred()) }) }) diff --git a/internal/mode/static/nginx/runtime/runtimefakes/fake_manager.go b/internal/mode/static/nginx/runtime/runtimefakes/fake_manager.go index 2538e32de3..ea7504a762 100644 --- a/internal/mode/static/nginx/runtime/runtimefakes/fake_manager.go +++ b/internal/mode/static/nginx/runtime/runtimefakes/fake_manager.go @@ -10,17 +10,19 @@ import ( ) type FakeManager struct { - GetUpstreamsStub func() (client.Upstreams, error) + GetUpstreamsStub func() (client.Upstreams, client.StreamUpstreams, error) getUpstreamsMutex sync.RWMutex getUpstreamsArgsForCall []struct { } getUpstreamsReturns struct { result1 client.Upstreams - result2 error + result2 client.StreamUpstreams + result3 error } getUpstreamsReturnsOnCall map[int]struct { result1 client.Upstreams - result2 error + result2 client.StreamUpstreams + result3 error } IsPlusStub func() bool isPlusMutex sync.RWMutex @@ -56,11 +58,23 @@ type FakeManager struct { updateHTTPServersReturnsOnCall map[int]struct { result1 error } + UpdateStreamServersStub func(string, []client.StreamUpstreamServer) error + updateStreamServersMutex sync.RWMutex + updateStreamServersArgsForCall []struct { + arg1 string + arg2 []client.StreamUpstreamServer + } + updateStreamServersReturns struct { + result1 error + } + updateStreamServersReturnsOnCall map[int]struct { + result1 error + } invocations map[string][][]interface{} invocationsMutex sync.RWMutex } -func (fake *FakeManager) GetUpstreams() (client.Upstreams, error) { +func (fake *FakeManager) GetUpstreams() (client.Upstreams, client.StreamUpstreams, error) { fake.getUpstreamsMutex.Lock() ret, specificReturn := fake.getUpstreamsReturnsOnCall[len(fake.getUpstreamsArgsForCall)] fake.getUpstreamsArgsForCall = append(fake.getUpstreamsArgsForCall, struct { @@ -73,9 +87,9 @@ func (fake *FakeManager) GetUpstreams() (client.Upstreams, error) { return stub() } if specificReturn { - return ret.result1, ret.result2 + return ret.result1, ret.result2, ret.result3 } - return fakeReturns.result1, fakeReturns.result2 + return fakeReturns.result1, fakeReturns.result2, fakeReturns.result3 } func (fake *FakeManager) GetUpstreamsCallCount() int { @@ -84,36 +98,39 @@ func (fake *FakeManager) GetUpstreamsCallCount() int { return len(fake.getUpstreamsArgsForCall) } -func (fake *FakeManager) GetUpstreamsCalls(stub func() (client.Upstreams, error)) { +func (fake *FakeManager) GetUpstreamsCalls(stub func() (client.Upstreams, client.StreamUpstreams, error)) { fake.getUpstreamsMutex.Lock() defer fake.getUpstreamsMutex.Unlock() fake.GetUpstreamsStub = stub } -func (fake *FakeManager) GetUpstreamsReturns(result1 client.Upstreams, result2 error) { +func (fake *FakeManager) GetUpstreamsReturns(result1 client.Upstreams, result2 client.StreamUpstreams, result3 error) { fake.getUpstreamsMutex.Lock() defer fake.getUpstreamsMutex.Unlock() fake.GetUpstreamsStub = nil fake.getUpstreamsReturns = struct { result1 client.Upstreams - result2 error - }{result1, result2} + result2 client.StreamUpstreams + result3 error + }{result1, result2, result3} } -func (fake *FakeManager) GetUpstreamsReturnsOnCall(i int, result1 client.Upstreams, result2 error) { +func (fake *FakeManager) GetUpstreamsReturnsOnCall(i int, result1 client.Upstreams, result2 client.StreamUpstreams, result3 error) { fake.getUpstreamsMutex.Lock() defer fake.getUpstreamsMutex.Unlock() fake.GetUpstreamsStub = nil if fake.getUpstreamsReturnsOnCall == nil { fake.getUpstreamsReturnsOnCall = make(map[int]struct { result1 client.Upstreams - result2 error + result2 client.StreamUpstreams + result3 error }) } fake.getUpstreamsReturnsOnCall[i] = struct { result1 client.Upstreams - result2 error - }{result1, result2} + result2 client.StreamUpstreams + result3 error + }{result1, result2, result3} } func (fake *FakeManager) IsPlus() bool { @@ -298,6 +315,73 @@ func (fake *FakeManager) UpdateHTTPServersReturnsOnCall(i int, result1 error) { }{result1} } +func (fake *FakeManager) UpdateStreamServers(arg1 string, arg2 []client.StreamUpstreamServer) error { + var arg2Copy []client.StreamUpstreamServer + if arg2 != nil { + arg2Copy = make([]client.StreamUpstreamServer, len(arg2)) + copy(arg2Copy, arg2) + } + fake.updateStreamServersMutex.Lock() + ret, specificReturn := fake.updateStreamServersReturnsOnCall[len(fake.updateStreamServersArgsForCall)] + fake.updateStreamServersArgsForCall = append(fake.updateStreamServersArgsForCall, struct { + arg1 string + arg2 []client.StreamUpstreamServer + }{arg1, arg2Copy}) + stub := fake.UpdateStreamServersStub + fakeReturns := fake.updateStreamServersReturns + fake.recordInvocation("UpdateStreamServers", []interface{}{arg1, arg2Copy}) + fake.updateStreamServersMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeManager) UpdateStreamServersCallCount() int { + fake.updateStreamServersMutex.RLock() + defer fake.updateStreamServersMutex.RUnlock() + return len(fake.updateStreamServersArgsForCall) +} + +func (fake *FakeManager) UpdateStreamServersCalls(stub func(string, []client.StreamUpstreamServer) error) { + fake.updateStreamServersMutex.Lock() + defer fake.updateStreamServersMutex.Unlock() + fake.UpdateStreamServersStub = stub +} + +func (fake *FakeManager) UpdateStreamServersArgsForCall(i int) (string, []client.StreamUpstreamServer) { + fake.updateStreamServersMutex.RLock() + defer fake.updateStreamServersMutex.RUnlock() + argsForCall := fake.updateStreamServersArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeManager) UpdateStreamServersReturns(result1 error) { + fake.updateStreamServersMutex.Lock() + defer fake.updateStreamServersMutex.Unlock() + fake.UpdateStreamServersStub = nil + fake.updateStreamServersReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeManager) UpdateStreamServersReturnsOnCall(i int, result1 error) { + fake.updateStreamServersMutex.Lock() + defer fake.updateStreamServersMutex.Unlock() + fake.UpdateStreamServersStub = nil + if fake.updateStreamServersReturnsOnCall == nil { + fake.updateStreamServersReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.updateStreamServersReturnsOnCall[i] = struct { + result1 error + }{result1} +} + func (fake *FakeManager) Invocations() map[string][][]interface{} { fake.invocationsMutex.RLock() defer fake.invocationsMutex.RUnlock() @@ -309,6 +393,8 @@ func (fake *FakeManager) Invocations() map[string][][]interface{} { defer fake.reloadMutex.RUnlock() fake.updateHTTPServersMutex.RLock() defer fake.updateHTTPServersMutex.RUnlock() + fake.updateStreamServersMutex.RLock() + defer fake.updateStreamServersMutex.RUnlock() copiedInvocations := map[string][][]interface{}{} for key, value := range fake.invocations { copiedInvocations[key] = value diff --git a/internal/mode/static/nginx/runtime/runtimefakes/fake_nginx_plus_client.go b/internal/mode/static/nginx/runtime/runtimefakes/fake_nginx_plus_client.go index 3ea431d29b..8001f7f8a7 100644 --- a/internal/mode/static/nginx/runtime/runtimefakes/fake_nginx_plus_client.go +++ b/internal/mode/static/nginx/runtime/runtimefakes/fake_nginx_plus_client.go @@ -9,6 +9,18 @@ import ( ) type FakeNginxPlusClient struct { + GetStreamUpstreamsStub func() (*client.StreamUpstreams, error) + getStreamUpstreamsMutex sync.RWMutex + getStreamUpstreamsArgsForCall []struct { + } + getStreamUpstreamsReturns struct { + result1 *client.StreamUpstreams + result2 error + } + getStreamUpstreamsReturnsOnCall map[int]struct { + result1 *client.StreamUpstreams + result2 error + } GetUpstreamsStub func() (*client.Upstreams, error) getUpstreamsMutex sync.RWMutex getUpstreamsArgsForCall []struct { @@ -39,10 +51,84 @@ type FakeNginxPlusClient struct { result3 []client.UpstreamServer result4 error } + UpdateStreamServersStub func(string, []client.StreamUpstreamServer) ([]client.StreamUpstreamServer, []client.StreamUpstreamServer, []client.StreamUpstreamServer, error) + updateStreamServersMutex sync.RWMutex + updateStreamServersArgsForCall []struct { + arg1 string + arg2 []client.StreamUpstreamServer + } + updateStreamServersReturns struct { + result1 []client.StreamUpstreamServer + result2 []client.StreamUpstreamServer + result3 []client.StreamUpstreamServer + result4 error + } + updateStreamServersReturnsOnCall map[int]struct { + result1 []client.StreamUpstreamServer + result2 []client.StreamUpstreamServer + result3 []client.StreamUpstreamServer + result4 error + } invocations map[string][][]interface{} invocationsMutex sync.RWMutex } +func (fake *FakeNginxPlusClient) GetStreamUpstreams() (*client.StreamUpstreams, error) { + fake.getStreamUpstreamsMutex.Lock() + ret, specificReturn := fake.getStreamUpstreamsReturnsOnCall[len(fake.getStreamUpstreamsArgsForCall)] + fake.getStreamUpstreamsArgsForCall = append(fake.getStreamUpstreamsArgsForCall, struct { + }{}) + stub := fake.GetStreamUpstreamsStub + fakeReturns := fake.getStreamUpstreamsReturns + fake.recordInvocation("GetStreamUpstreams", []interface{}{}) + fake.getStreamUpstreamsMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeNginxPlusClient) GetStreamUpstreamsCallCount() int { + fake.getStreamUpstreamsMutex.RLock() + defer fake.getStreamUpstreamsMutex.RUnlock() + return len(fake.getStreamUpstreamsArgsForCall) +} + +func (fake *FakeNginxPlusClient) GetStreamUpstreamsCalls(stub func() (*client.StreamUpstreams, error)) { + fake.getStreamUpstreamsMutex.Lock() + defer fake.getStreamUpstreamsMutex.Unlock() + fake.GetStreamUpstreamsStub = stub +} + +func (fake *FakeNginxPlusClient) GetStreamUpstreamsReturns(result1 *client.StreamUpstreams, result2 error) { + fake.getStreamUpstreamsMutex.Lock() + defer fake.getStreamUpstreamsMutex.Unlock() + fake.GetStreamUpstreamsStub = nil + fake.getStreamUpstreamsReturns = struct { + result1 *client.StreamUpstreams + result2 error + }{result1, result2} +} + +func (fake *FakeNginxPlusClient) GetStreamUpstreamsReturnsOnCall(i int, result1 *client.StreamUpstreams, result2 error) { + fake.getStreamUpstreamsMutex.Lock() + defer fake.getStreamUpstreamsMutex.Unlock() + fake.GetStreamUpstreamsStub = nil + if fake.getStreamUpstreamsReturnsOnCall == nil { + fake.getStreamUpstreamsReturnsOnCall = make(map[int]struct { + result1 *client.StreamUpstreams + result2 error + }) + } + fake.getStreamUpstreamsReturnsOnCall[i] = struct { + result1 *client.StreamUpstreams + result2 error + }{result1, result2} +} + func (fake *FakeNginxPlusClient) GetUpstreams() (*client.Upstreams, error) { fake.getUpstreamsMutex.Lock() ret, specificReturn := fake.getUpstreamsReturnsOnCall[len(fake.getUpstreamsArgsForCall)] @@ -175,13 +261,93 @@ func (fake *FakeNginxPlusClient) UpdateHTTPServersReturnsOnCall(i int, result1 [ }{result1, result2, result3, result4} } +func (fake *FakeNginxPlusClient) UpdateStreamServers(arg1 string, arg2 []client.StreamUpstreamServer) ([]client.StreamUpstreamServer, []client.StreamUpstreamServer, []client.StreamUpstreamServer, error) { + var arg2Copy []client.StreamUpstreamServer + if arg2 != nil { + arg2Copy = make([]client.StreamUpstreamServer, len(arg2)) + copy(arg2Copy, arg2) + } + fake.updateStreamServersMutex.Lock() + ret, specificReturn := fake.updateStreamServersReturnsOnCall[len(fake.updateStreamServersArgsForCall)] + fake.updateStreamServersArgsForCall = append(fake.updateStreamServersArgsForCall, struct { + arg1 string + arg2 []client.StreamUpstreamServer + }{arg1, arg2Copy}) + stub := fake.UpdateStreamServersStub + fakeReturns := fake.updateStreamServersReturns + fake.recordInvocation("UpdateStreamServers", []interface{}{arg1, arg2Copy}) + fake.updateStreamServersMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2, ret.result3, ret.result4 + } + return fakeReturns.result1, fakeReturns.result2, fakeReturns.result3, fakeReturns.result4 +} + +func (fake *FakeNginxPlusClient) UpdateStreamServersCallCount() int { + fake.updateStreamServersMutex.RLock() + defer fake.updateStreamServersMutex.RUnlock() + return len(fake.updateStreamServersArgsForCall) +} + +func (fake *FakeNginxPlusClient) UpdateStreamServersCalls(stub func(string, []client.StreamUpstreamServer) ([]client.StreamUpstreamServer, []client.StreamUpstreamServer, []client.StreamUpstreamServer, error)) { + fake.updateStreamServersMutex.Lock() + defer fake.updateStreamServersMutex.Unlock() + fake.UpdateStreamServersStub = stub +} + +func (fake *FakeNginxPlusClient) UpdateStreamServersArgsForCall(i int) (string, []client.StreamUpstreamServer) { + fake.updateStreamServersMutex.RLock() + defer fake.updateStreamServersMutex.RUnlock() + argsForCall := fake.updateStreamServersArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeNginxPlusClient) UpdateStreamServersReturns(result1 []client.StreamUpstreamServer, result2 []client.StreamUpstreamServer, result3 []client.StreamUpstreamServer, result4 error) { + fake.updateStreamServersMutex.Lock() + defer fake.updateStreamServersMutex.Unlock() + fake.UpdateStreamServersStub = nil + fake.updateStreamServersReturns = struct { + result1 []client.StreamUpstreamServer + result2 []client.StreamUpstreamServer + result3 []client.StreamUpstreamServer + result4 error + }{result1, result2, result3, result4} +} + +func (fake *FakeNginxPlusClient) UpdateStreamServersReturnsOnCall(i int, result1 []client.StreamUpstreamServer, result2 []client.StreamUpstreamServer, result3 []client.StreamUpstreamServer, result4 error) { + fake.updateStreamServersMutex.Lock() + defer fake.updateStreamServersMutex.Unlock() + fake.UpdateStreamServersStub = nil + if fake.updateStreamServersReturnsOnCall == nil { + fake.updateStreamServersReturnsOnCall = make(map[int]struct { + result1 []client.StreamUpstreamServer + result2 []client.StreamUpstreamServer + result3 []client.StreamUpstreamServer + result4 error + }) + } + fake.updateStreamServersReturnsOnCall[i] = struct { + result1 []client.StreamUpstreamServer + result2 []client.StreamUpstreamServer + result3 []client.StreamUpstreamServer + result4 error + }{result1, result2, result3, result4} +} + func (fake *FakeNginxPlusClient) Invocations() map[string][][]interface{} { fake.invocationsMutex.RLock() defer fake.invocationsMutex.RUnlock() + fake.getStreamUpstreamsMutex.RLock() + defer fake.getStreamUpstreamsMutex.RUnlock() fake.getUpstreamsMutex.RLock() defer fake.getUpstreamsMutex.RUnlock() fake.updateHTTPServersMutex.RLock() defer fake.updateHTTPServersMutex.RUnlock() + fake.updateStreamServersMutex.RLock() + defer fake.updateStreamServersMutex.RUnlock() copiedInvocations := map[string][][]interface{}{} for key, value := range fake.invocations { copiedInvocations[key] = value