From 5569e94cdaad4e48c058839f22bab0eee1482dbc Mon Sep 17 00:00:00 2001
From: Luca Comellini <luca.com@gmail.com>
Date: Tue, 19 Sep 2023 11:17:57 -0700
Subject: [PATCH] Refactor NGINX Client to use options (#153)

---
 client/nginx.go                | 104 ++++++++++++++++++++++-----------
 client/nginx_test.go           |  71 ++++++++++++++++++++++
 tests/client_no_stream_test.go |   4 +-
 tests/client_test.go           |  41 +++++--------
 4 files changed, 157 insertions(+), 63 deletions(-)

diff --git a/client/nginx.go b/client/nginx.go
index c6fb9f68..206ecf8f 100644
--- a/client/nginx.go
+++ b/client/nginx.go
@@ -41,11 +41,14 @@ var ErrUnsupportedVer = errors.New("API version of the client is not supported b
 
 // NginxClient lets you access NGINX Plus API.
 type NginxClient struct {
-	version     int
+	apiVersion  int
 	apiEndpoint string
 	httpClient  *http.Client
+	checkAPI    bool
 }
 
+type Option func(*NginxClient)
+
 type versions []int
 
 // UpstreamServer lets you configure HTTP upstreams.
@@ -508,35 +511,66 @@ type WorkersHTTP struct {
 	HTTPRequests HTTPRequests `json:"requests"`
 }
 
-// NewNginxClient creates an NginxClient with the latest supported version.
-func NewNginxClient(httpClient *http.Client, apiEndpoint string) (*NginxClient, error) {
-	return NewNginxClientWithVersion(httpClient, apiEndpoint, APIVersion)
+// WithHTTPClient sets the HTTP client to use for accessing the API.
+func WithHTTPClient(httpClient *http.Client) Option {
+	return func(o *NginxClient) {
+		o.httpClient = httpClient
+	}
 }
 
-// NewNginxClientWithVersion creates an NginxClient with the given version of NGINX Plus API.
-func NewNginxClientWithVersion(httpClient *http.Client, apiEndpoint string, version int) (*NginxClient, error) {
-	if !versionSupported(version) {
-		return nil, fmt.Errorf("API version %v is not supported by the client", version)
+// WithAPIVersion sets the API version to use for accessing the API.
+func WithAPIVersion(apiVersion int) Option {
+	return func(o *NginxClient) {
+		o.apiVersion = apiVersion
 	}
-	versions, err := getAPIVersions(httpClient, apiEndpoint)
-	if err != nil {
-		return nil, fmt.Errorf("error accessing the API: %w", err)
+}
+
+// WithCheckAPI sets the flag to check the API version of the server.
+func WithCheckAPI() Option {
+	return func(o *NginxClient) {
+		o.checkAPI = true
 	}
-	found := false
-	for _, v := range *versions {
-		if v == version {
-			found = true
-			break
-		}
+}
+
+// NewNginxClient creates a new NginxClient.
+func NewNginxClient(apiEndpoint string, opts ...Option) (*NginxClient, error) {
+	c := &NginxClient{
+		httpClient:  http.DefaultClient,
+		apiEndpoint: apiEndpoint,
+		apiVersion:  APIVersion,
+		checkAPI:    false,
 	}
-	if !found {
-		return nil, ErrUnsupportedVer
+
+	for _, opt := range opts {
+		opt(c)
 	}
-	return &NginxClient{
-		apiEndpoint: apiEndpoint,
-		httpClient:  httpClient,
-		version:     version,
-	}, nil
+
+	if c.httpClient == nil {
+		return nil, fmt.Errorf("http client is not set")
+	}
+
+	if !versionSupported(c.apiVersion) {
+		return nil, fmt.Errorf("API version %v is not supported by the client", c.apiVersion)
+	}
+
+	if c.checkAPI {
+		versions, err := getAPIVersions(c.httpClient, apiEndpoint)
+		if err != nil {
+			return nil, fmt.Errorf("error accessing the API: %w", err)
+		}
+		found := false
+		for _, v := range *versions {
+			if v == c.apiVersion {
+				found = true
+				break
+			}
+		}
+		if !found {
+			return nil, fmt.Errorf("API version %v is not supported by the server", c.apiVersion)
+		}
+	}
+
+	return c, nil
 }
 
 func versionSupported(n int) bool {
@@ -807,7 +841,7 @@ func (client *NginxClient) get(path string, data interface{}) error {
 	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
 	defer cancel()
 
-	url := fmt.Sprintf("%v/%v/%v", client.apiEndpoint, client.version, path)
+	url := fmt.Sprintf("%v/%v/%v", client.apiEndpoint, client.apiVersion, path)
 
 	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
 	if err != nil {
@@ -841,7 +875,7 @@ func (client *NginxClient) post(path string, input interface{}) error {
 	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
 	defer cancel()
 
-	url := fmt.Sprintf("%v/%v/%v", client.apiEndpoint, client.version, path)
+	url := fmt.Sprintf("%v/%v/%v", client.apiEndpoint, client.apiVersion, path)
 
 	jsonInput, err := json.Marshal(input)
 	if err != nil {
@@ -873,7 +907,7 @@ func (client *NginxClient) delete(path string, expectedStatusCode int) error {
 	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
 	defer cancel()
 
-	path = fmt.Sprintf("%v/%v/%v/", client.apiEndpoint, client.version, path)
+	path = fmt.Sprintf("%v/%v/%v/", client.apiEndpoint, client.apiVersion, path)
 
 	req, err := http.NewRequestWithContext(ctx, http.MethodDelete, path, nil)
 	if err != nil {
@@ -898,7 +932,7 @@ func (client *NginxClient) patch(path string, input interface{}, expectedStatusC
 	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
 	defer cancel()
 
-	path = fmt.Sprintf("%v/%v/%v/", client.apiEndpoint, client.version, path)
+	path = fmt.Sprintf("%v/%v/%v/", client.apiEndpoint, client.apiVersion, path)
 
 	jsonInput, err := json.Marshal(input)
 	if err != nil {
@@ -1359,7 +1393,7 @@ func (client *NginxClient) GetStreamZoneSync() (*StreamZoneSync, error) {
 // GetLocationZones returns http/location_zones stats.
 func (client *NginxClient) GetLocationZones() (*LocationZones, error) {
 	var locationZones LocationZones
-	if client.version < 5 {
+	if client.apiVersion < 5 {
 		return &locationZones, nil
 	}
 	err := client.get("http/location_zones", &locationZones)
@@ -1373,7 +1407,7 @@ func (client *NginxClient) GetLocationZones() (*LocationZones, error) {
 // GetResolvers returns Resolvers stats.
 func (client *NginxClient) GetResolvers() (*Resolvers, error) {
 	var resolvers Resolvers
-	if client.version < 5 {
+	if client.apiVersion < 5 {
 		return &resolvers, nil
 	}
 	err := client.get("resolvers", &resolvers)
@@ -1596,7 +1630,7 @@ func (client *NginxClient) UpdateStreamServer(upstream string, server StreamUpst
 
 // Version returns client's current N+ API version.
 func (client *NginxClient) Version() int {
-	return client.version
+	return client.apiVersion
 }
 
 func addPortToServer(server string) string {
@@ -1618,7 +1652,7 @@ func addPortToServer(server string) string {
 // GetHTTPLimitReqs returns http/limit_reqs stats.
 func (client *NginxClient) GetHTTPLimitReqs() (*HTTPLimitRequests, error) {
 	var limitReqs HTTPLimitRequests
-	if client.version < 6 {
+	if client.apiVersion < 6 {
 		return &limitReqs, nil
 	}
 	err := client.get("http/limit_reqs", &limitReqs)
@@ -1631,7 +1665,7 @@ func (client *NginxClient) GetHTTPLimitReqs() (*HTTPLimitRequests, error) {
 // GetHTTPConnectionsLimit returns http/limit_conns stats.
 func (client *NginxClient) GetHTTPConnectionsLimit() (*HTTPLimitConnections, error) {
 	var limitConns HTTPLimitConnections
-	if client.version < 6 {
+	if client.apiVersion < 6 {
 		return &limitConns, nil
 	}
 	err := client.get("http/limit_conns", &limitConns)
@@ -1644,7 +1678,7 @@ func (client *NginxClient) GetHTTPConnectionsLimit() (*HTTPLimitConnections, err
 // GetStreamConnectionsLimit returns stream/limit_conns stats.
 func (client *NginxClient) GetStreamConnectionsLimit() (*StreamLimitConnections, error) {
 	var limitConns StreamLimitConnections
-	if client.version < 6 {
+	if client.apiVersion < 6 {
 		return &limitConns, nil
 	}
 	err := client.get("stream/limit_conns", &limitConns)
@@ -1663,7 +1697,7 @@ func (client *NginxClient) GetStreamConnectionsLimit() (*StreamLimitConnections,
 // GetWorkers returns workers stats.
 func (client *NginxClient) GetWorkers() ([]*Workers, error) {
 	var workers []*Workers
-	if client.version < 9 {
+	if client.apiVersion < 9 {
 		return workers, nil
 	}
 	err := client.get("workers", &workers)
diff --git a/client/nginx_test.go b/client/nginx_test.go
index b8aa4eb8..0acf9914 100644
--- a/client/nginx_test.go
+++ b/client/nginx_test.go
@@ -1,6 +1,8 @@
 package client
 
 import (
+	"net/http"
+	"net/http/httptest"
 	"reflect"
 	"testing"
 )
@@ -518,3 +520,72 @@ func TestHaveSameParametersForStream(t *testing.T) {
 		}
 	}
 }
+
+func TestClientWithCheckAPI(t *testing.T) {
+	// Create a test server that returns supported API versions
+	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		_, err := w.Write([]byte(`[4, 5, 6, 7]`))
+		if err != nil {
+			t.Fatalf("unexpected error: %v", err)
+		}
+	}))
+	defer ts.Close()
+
+	// Test creating a new client with a supported API version on the server
+	client, err := NewNginxClient(ts.URL, WithAPIVersion(7), WithCheckAPI())
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if client == nil {
+		t.Fatalf("client is nil")
+	}
+
+	// Test creating a new client with an unsupported API version on the server
+	client, err = NewNginxClient(ts.URL, WithAPIVersion(8), WithCheckAPI())
+	if err == nil {
+		t.Fatalf("expected error, but got nil")
+	}
+	if client != nil {
+		t.Fatalf("expected client to be nil, but got %v", client)
+	}
+}
+
+func TestClientWithAPIVersion(t *testing.T) {
+	// Test creating a new client with a supported API version on the client
+	client, err := NewNginxClient("http://api-url", WithAPIVersion(8))
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if client == nil {
+		t.Fatalf("client is nil")
+	}
+
+	// Test creating a new client with an unsupported API version on the client
+	client, err = NewNginxClient("http://api-url", WithAPIVersion(3))
+	if err == nil {
+		t.Fatalf("expected error, but got nil")
+	}
+	if client != nil {
+		t.Fatalf("expected client to be nil, but got %v", client)
+	}
+}
+
+func TestClientWithHTTPClient(t *testing.T) {
+	// Test creating a new client passing a custom HTTP client
+	client, err := NewNginxClient("http://api-url", WithHTTPClient(&http.Client{}))
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if client == nil {
+		t.Fatalf("client is nil")
+	}
+
+	// Test creating a new client passing a nil HTTP client
+	client, err = NewNginxClient("http://api-url", WithHTTPClient(nil))
+	if err == nil {
+		t.Fatalf("expected error, but got nil")
+	}
+	if client != nil {
+		t.Fatalf("expected client to be nil, but got %v", client)
+	}
+}
diff --git a/tests/client_no_stream_test.go b/tests/client_no_stream_test.go
index 6a7221a4..cb29d465 100644
--- a/tests/client_no_stream_test.go
+++ b/tests/client_no_stream_test.go
@@ -1,7 +1,6 @@
 package tests
 
 import (
-	"net/http"
 	"testing"
 
 	"github.com/nginxinc/nginx-plus-go-client/client"
@@ -13,8 +12,7 @@ import (
 // The API returns a special error code that we can use to determine if the API
 // is misconfigured or of the stream block is missing.
 func TestStatsNoStream(t *testing.T) {
-	httpClient := &http.Client{}
-	c, err := client.NewNginxClient(httpClient, helpers.GetAPIEndpoint())
+	c, err := client.NewNginxClient(helpers.GetAPIEndpoint())
 	if err != nil {
 		t.Fatalf("Error connecting to nginx: %v", err)
 	}
diff --git a/tests/client_test.go b/tests/client_test.go
index 27fb84e9..f580ad71 100644
--- a/tests/client_test.go
+++ b/tests/client_test.go
@@ -2,7 +2,6 @@ package tests
 
 import (
 	"net"
-	"net/http"
 	"reflect"
 	"testing"
 	"time"
@@ -34,8 +33,10 @@ var (
 )
 
 func TestStreamClient(t *testing.T) {
-	httpClient := &http.Client{}
-	c, err := client.NewNginxClient(httpClient, helpers.GetAPIEndpoint())
+	c, err := client.NewNginxClient(
+		helpers.GetAPIEndpoint(),
+		client.WithCheckAPI(),
+	)
 	if err != nil {
 		t.Fatalf("Error when creating a client: %v", err)
 	}
@@ -254,8 +255,7 @@ func TestStreamClient(t *testing.T) {
 }
 
 func TestStreamUpstreamServer(t *testing.T) {
-	httpClient := &http.Client{}
-	c, err := client.NewNginxClient(httpClient, helpers.GetAPIEndpoint())
+	c, err := client.NewNginxClient(helpers.GetAPIEndpoint())
 	if err != nil {
 		t.Fatalf("Error connecting to nginx: %v", err)
 	}
@@ -302,8 +302,7 @@ func TestStreamUpstreamServer(t *testing.T) {
 }
 
 func TestClient(t *testing.T) {
-	httpClient := &http.Client{}
-	c, err := client.NewNginxClient(httpClient, helpers.GetAPIEndpoint())
+	c, err := client.NewNginxClient(helpers.GetAPIEndpoint())
 	if err != nil {
 		t.Fatalf("Error when creating a client: %v", err)
 	}
@@ -529,8 +528,7 @@ func TestClient(t *testing.T) {
 }
 
 func TestUpstreamServer(t *testing.T) {
-	httpClient := &http.Client{}
-	c, err := client.NewNginxClient(httpClient, helpers.GetAPIEndpoint())
+	c, err := client.NewNginxClient(helpers.GetAPIEndpoint())
 	if err != nil {
 		t.Fatalf("Error connecting to nginx: %v", err)
 	}
@@ -578,8 +576,7 @@ func TestUpstreamServer(t *testing.T) {
 }
 
 func TestStats(t *testing.T) {
-	httpClient := &http.Client{}
-	c, err := client.NewNginxClient(httpClient, helpers.GetAPIEndpoint())
+	c, err := client.NewNginxClient(helpers.GetAPIEndpoint())
 	if err != nil {
 		t.Fatalf("Error connecting to nginx: %v", err)
 	}
@@ -720,8 +717,7 @@ func TestStats(t *testing.T) {
 }
 
 func TestUpstreamServerDefaultParameters(t *testing.T) {
-	httpClient := &http.Client{}
-	c, err := client.NewNginxClient(httpClient, helpers.GetAPIEndpoint())
+	c, err := client.NewNginxClient(helpers.GetAPIEndpoint())
 	if err != nil {
 		t.Fatalf("Error connecting to nginx: %v", err)
 	}
@@ -770,8 +766,7 @@ func TestUpstreamServerDefaultParameters(t *testing.T) {
 }
 
 func TestStreamStats(t *testing.T) {
-	httpClient := &http.Client{}
-	c, err := client.NewNginxClient(httpClient, helpers.GetAPIEndpoint())
+	c, err := client.NewNginxClient(helpers.GetAPIEndpoint())
 	if err != nil {
 		t.Fatalf("Error connecting to nginx: %v", err)
 	}
@@ -848,8 +843,7 @@ func TestStreamStats(t *testing.T) {
 }
 
 func TestStreamUpstreamServerDefaultParameters(t *testing.T) {
-	httpClient := &http.Client{}
-	c, err := client.NewNginxClient(httpClient, helpers.GetAPIEndpoint())
+	c, err := client.NewNginxClient(helpers.GetAPIEndpoint())
 	if err != nil {
 		t.Fatalf("Error connecting to nginx: %v", err)
 	}
@@ -897,8 +891,7 @@ func TestStreamUpstreamServerDefaultParameters(t *testing.T) {
 
 func TestKeyValue(t *testing.T) {
 	zoneName := "zone_one"
-	httpClient := &http.Client{}
-	c, err := client.NewNginxClient(httpClient, helpers.GetAPIEndpoint())
+	c, err := client.NewNginxClient(helpers.GetAPIEndpoint())
 	if err != nil {
 		t.Fatalf("Error connecting to nginx: %v", err)
 	}
@@ -995,8 +988,7 @@ func TestKeyValue(t *testing.T) {
 
 func TestKeyValueStream(t *testing.T) {
 	zoneName := "zone_one_stream"
-	httpClient := &http.Client{}
-	c, err := client.NewNginxClient(httpClient, helpers.GetAPIEndpoint())
+	c, err := client.NewNginxClient(helpers.GetAPIEndpoint())
 	if err != nil {
 		t.Fatalf("Error connecting to nginx: %v", err)
 	}
@@ -1092,12 +1084,12 @@ func TestKeyValueStream(t *testing.T) {
 }
 
 func TestStreamZoneSync(t *testing.T) {
-	c1, err := client.NewNginxClient(&http.Client{}, helpers.GetAPIEndpoint())
+	c1, err := client.NewNginxClient(helpers.GetAPIEndpoint())
 	if err != nil {
 		t.Fatalf("Error connecting to nginx: %v", err)
 	}
 
-	c2, err := client.NewNginxClient(&http.Client{}, helpers.GetAPIEndpointOfHelper())
+	c2, err := client.NewNginxClient(helpers.GetAPIEndpointOfHelper())
 	if err != nil {
 		t.Fatalf("Error connecting to nginx: %v", err)
 	}
@@ -1218,8 +1210,7 @@ func compareStreamUpstreamServers(x []client.StreamUpstreamServer, y []client.St
 }
 
 func TestUpstreamServerWithDrain(t *testing.T) {
-	httpClient := &http.Client{}
-	c, err := client.NewNginxClient(httpClient, helpers.GetAPIEndpoint())
+	c, err := client.NewNginxClient(helpers.GetAPIEndpoint())
 	if err != nil {
 		t.Fatalf("Error connecting to nginx: %v", err)
 	}