From 19b8ad15393cff66efa1d228267f342a78ddc22a Mon Sep 17 00:00:00 2001 From: Jakub Jarosz <99677300+jjngx@users.noreply.github.com> Date: Mon, 8 May 2023 21:52:57 +0100 Subject: [PATCH] Simplify validators (#3818) * Simplify validators --- pkg/apis/configuration/validation/common.go | 54 ++- .../validation/globalconfiguration.go | 15 +- pkg/apis/configuration/validation/policy.go | 148 +++----- .../configuration/validation/policy_test.go | 81 ++++- .../validation/transportserver.go | 140 +++----- .../configuration/validation/virtualserver.go | 318 ++++++------------ 6 files changed, 289 insertions(+), 467 deletions(-) diff --git a/pkg/apis/configuration/validation/common.go b/pkg/apis/configuration/validation/common.go index ed6dfc8106..e094b1fa95 100644 --- a/pkg/apis/configuration/validation/common.go +++ b/pkg/apis/configuration/validation/common.go @@ -27,13 +27,11 @@ func ValidateEscapedString(body string, examples ...string) error { } func validateVariable(nVar string, validVars map[string]bool, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if !validVars[nVar] { msg := fmt.Sprintf("'%v' contains an invalid NGINX variable. Accepted variables are: %v", nVar, mapToPrettyString(validVars)) - allErrs = append(allErrs, field.Invalid(fieldPath, nVar, msg)) + return field.ErrorList{field.Invalid(fieldPath, nVar, msg)} } - return allErrs + return nil } // isValidSpecialHeaderLikeVariable validates special variables $http_, $jwt_header_, $jwt_claim_ @@ -103,10 +101,8 @@ func validateSpecialVariable(nVar string, fieldPath *field.Path, isPlus bool) fi } func validateStringWithVariables(str string, fieldPath *field.Path, specialVars []string, validVars map[string]bool, isPlus bool) field.ErrorList { - allErrs := field.ErrorList{} - if strings.HasSuffix(str, "$") { - return append(allErrs, field.Invalid(fieldPath, str, "must not end with $")) + return field.ErrorList{field.Invalid(fieldPath, str, "must not end with $")} } for i, c := range str { @@ -114,15 +110,16 @@ func validateStringWithVariables(str string, fieldPath *field.Path, specialVars msg := "variables must be enclosed in curly braces, for example ${host}" if str[i+1] != '{' { - return append(allErrs, field.Invalid(fieldPath, str, msg)) + return field.ErrorList{field.Invalid(fieldPath, str, msg)} } if !strings.Contains(str[i+1:], "}") { - return append(allErrs, field.Invalid(fieldPath, str, msg)) + return field.ErrorList{field.Invalid(fieldPath, str, msg)} } } } + allErrs := field.ErrorList{} nginxVars := captureVariables(str) for _, nVar := range nginxVars { special := false @@ -144,67 +141,56 @@ func validateStringWithVariables(str string, fieldPath *field.Path, specialVars } func validateTime(time string, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if time == "" { - return allErrs + return nil } - if _, err := configs.ParseTime(time); err != nil { - return append(allErrs, field.Invalid(fieldPath, time, err.Error())) + return field.ErrorList{field.Invalid(fieldPath, time, err.Error())} } - - return allErrs + return nil } // http://nginx.org/en/docs/syntax.html const offsetErrMsg = "must consist of numeric characters followed by a valid size suffix. 'k|K|m|M|g|G" func validateOffset(offset string, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if offset == "" { - return allErrs + return nil } if _, err := configs.ParseOffset(offset); err != nil { msg := validation.RegexError(offsetErrMsg, configs.OffsetFmt, "16", "32k", "64M", "2G") - return append(allErrs, field.Invalid(fieldPath, offset, msg)) + return field.ErrorList{field.Invalid(fieldPath, offset, msg)} } - - return allErrs + return nil } // http://nginx.org/en/docs/syntax.html const sizeErrMsg = "must consist of numeric characters followed by a valid size suffix. 'k|K|m|M" func validateSize(size string, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if size == "" { - return allErrs + return nil } if _, err := configs.ParseSize(size); err != nil { msg := validation.RegexError(sizeErrMsg, configs.SizeFmt, "16", "32k", "64M") - return append(allErrs, field.Invalid(fieldPath, size, msg)) + return field.ErrorList{field.Invalid(fieldPath, size, msg)} } - return allErrs + return nil } // validateSecretName checks if a secret name is valid. // It performs the same validation as ValidateSecretName from k8s.io/kubernetes/pkg/apis/core/validation/validation.go. func validateSecretName(name string, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if name == "" { - return allErrs + return nil } + allErrs := field.ErrorList{} for _, msg := range validation.IsDNS1123Subdomain(name) { allErrs = append(allErrs, field.Invalid(fieldPath, name, msg)) } - return allErrs } @@ -220,11 +206,9 @@ func mapToPrettyString(m map[string]bool) string { // ValidateParameter validates a parameter against a map of valid parameters for the directive func ValidateParameter(nPar string, validParams map[string]bool, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if !validParams[nPar] { msg := fmt.Sprintf("'%v' contains an invalid NGINX parameter. Accepted parameters are: %v", nPar, mapToPrettyString(validParams)) - allErrs = append(allErrs, field.Invalid(fieldPath, nPar, msg)) + return field.ErrorList{field.Invalid(fieldPath, nPar, msg)} } - return allErrs + return nil } diff --git a/pkg/apis/configuration/validation/globalconfiguration.go b/pkg/apis/configuration/validation/globalconfiguration.go index a4454865e1..ccee08d1f8 100644 --- a/pkg/apis/configuration/validation/globalconfiguration.go +++ b/pkg/apis/configuration/validation/globalconfiguration.go @@ -63,9 +63,7 @@ func generatePortProtocolKey(port int, protocol string) string { } func (gcv *GlobalConfigurationValidator) validateListener(listener v1alpha1.Listener, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - - allErrs = append(allErrs, validateGlobalConfigurationListenerName(listener.Name, fieldPath.Child("name"))...) + allErrs := validateGlobalConfigurationListenerName(listener.Name, fieldPath.Child("name")) allErrs = append(allErrs, gcv.validateListenerPort(listener.Port, fieldPath.Child("port"))...) allErrs = append(allErrs, validateListenerProtocol(listener.Protocol, fieldPath.Child("protocol"))...) @@ -73,26 +71,21 @@ func (gcv *GlobalConfigurationValidator) validateListener(listener v1alpha1.List } func validateGlobalConfigurationListenerName(name string, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if name == v1alpha1.TLSPassthroughListenerName { - return append(allErrs, field.Forbidden(fieldPath, "is the name of a built-in listener")) + return field.ErrorList{field.Forbidden(fieldPath, "is the name of a built-in listener")} } - return validateListenerName(name, fieldPath) } func (gcv *GlobalConfigurationValidator) validateListenerPort(port int, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if gcv.forbiddenListenerPorts[port] { msg := fmt.Sprintf("port %v is forbidden", port) - return append(allErrs, field.Forbidden(fieldPath, msg)) + return field.ErrorList{field.Forbidden(fieldPath, msg)} } + allErrs := field.ErrorList{} for _, msg := range validation.IsValidPortNum(port) { allErrs = append(allErrs, field.Invalid(fieldPath, port, msg)) } - return allErrs } diff --git a/pkg/apis/configuration/validation/policy.go b/pkg/apis/configuration/validation/policy.go index c822527d1d..73bddd0d15 100644 --- a/pkg/apis/configuration/validation/policy.go +++ b/pkg/apis/configuration/validation/policy.go @@ -122,9 +122,7 @@ func validateAccessControl(accessControl *v1.AccessControl, fieldPath *field.Pat } func validateRateLimit(rateLimit *v1.RateLimit, fieldPath *field.Path, isPlus bool) field.ErrorList { - allErrs := field.ErrorList{} - - allErrs = append(allErrs, validateRateLimitZoneSize(rateLimit.ZoneSize, fieldPath.Child("zoneSize"))...) + allErrs := validateRateLimitZoneSize(rateLimit.ZoneSize, fieldPath.Child("zoneSize")) allErrs = append(allErrs, validateRate(rateLimit.Rate, fieldPath.Child("rate"))...) allErrs = append(allErrs, validateRateLimitKey(rateLimit.Key, fieldPath.Child("key"), isPlus)...) @@ -202,15 +200,12 @@ func validateBasic(basic *v1.BasicAuth, fieldPath *field.Path) field.ErrorList { } func validateIngressMTLS(ingressMTLS *v1.IngressMTLS, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if ingressMTLS.ClientCertSecret == "" { - return append(allErrs, field.Required(fieldPath.Child("clientCertSecret"), "")) + return field.ErrorList{field.Required(fieldPath.Child("clientCertSecret"), "")} } - allErrs = append(allErrs, validateSecretName(ingressMTLS.ClientCertSecret, fieldPath.Child("clientCertSecret"))...) + allErrs := validateSecretName(ingressMTLS.ClientCertSecret, fieldPath.Child("clientCertSecret")) allErrs = append(allErrs, validateIngressMTLSVerifyClient(ingressMTLS.VerifyClient, fieldPath.Child("verifyClient"))...) - if ingressMTLS.VerifyDepth != nil { allErrs = append(allErrs, validatePositiveIntOrZero(*ingressMTLS.VerifyDepth, fieldPath.Child("verifyDepth"))...) } @@ -218,9 +213,7 @@ func validateIngressMTLS(ingressMTLS *v1.IngressMTLS, fieldPath *field.Path) fie } func validateEgressMTLS(egressMTLS *v1.EgressMTLS, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - - allErrs = append(allErrs, validateSecretName(egressMTLS.TLSSecret, fieldPath.Child("tlsSecret"))...) + allErrs := validateSecretName(egressMTLS.TLSSecret, fieldPath.Child("tlsSecret")) if egressMTLS.VerifyServer && egressMTLS.TrustedCertSecret == "" { return append(allErrs, field.Required(fieldPath.Child("trustedCertSecret"), "must be set when verifyServer is 'true'")) @@ -237,36 +230,32 @@ func validateEgressMTLS(egressMTLS *v1.EgressMTLS, fieldPath *field.Path) field. } func validateOIDC(oidc *v1.OIDC, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if oidc.AuthEndpoint == "" { - return append(allErrs, field.Required(fieldPath.Child("authEndpoint"), "")) + return field.ErrorList{field.Required(fieldPath.Child("authEndpoint"), "")} } if oidc.TokenEndpoint == "" { - return append(allErrs, field.Required(fieldPath.Child("tokenEndpoint"), "")) + return field.ErrorList{field.Required(fieldPath.Child("tokenEndpoint"), "")} } if oidc.JWKSURI == "" { - return append(allErrs, field.Required(fieldPath.Child("jwksURI"), "")) + return field.ErrorList{field.Required(fieldPath.Child("jwksURI"), "")} } if oidc.ClientID == "" { - return append(allErrs, field.Required(fieldPath.Child("clientID"), "")) + return field.ErrorList{field.Required(fieldPath.Child("clientID"), "")} } if oidc.ClientSecret == "" { - return append(allErrs, field.Required(fieldPath.Child("clientSecret"), "")) + return field.ErrorList{field.Required(fieldPath.Child("clientSecret"), "")} } + allErrs := field.ErrorList{} if oidc.Scope != "" { allErrs = append(allErrs, validateOIDCScope(oidc.Scope, fieldPath.Child("scope"))...) } - if oidc.RedirectURI != "" { allErrs = append(allErrs, validatePath(oidc.RedirectURI, fieldPath.Child("redirectURI"))...) } - if oidc.ZoneSyncLeeway != nil { allErrs = append(allErrs, validatePositiveIntOrZero(*oidc.ZoneSyncLeeway, fieldPath.Child("zoneSyncLeeway"))...) } - if oidc.AuthExtraArgs != nil { allErrs = append(allErrs, validateQueryString(strings.Join(oidc.AuthExtraArgs, "&"), fieldPath.Child("authExtraArgs"))...) } @@ -352,12 +341,11 @@ var validScopes = map[string]bool{ // https://openid.net/specs/openid-connect-core-1_0.html#ScopeClaims func validateOIDCScope(scope string, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if !strings.Contains(scope, "openid") { - return append(allErrs, field.Required(fieldPath, "openid scope")) + return field.ErrorList{field.Required(fieldPath, "openid scope")} } + allErrs := field.ErrorList{} s := strings.Split(scope, "+") for _, v := range s { if !validScopes[v] { @@ -365,29 +353,22 @@ func validateOIDCScope(scope string, fieldPath *field.Path) field.ErrorList { allErrs = append(allErrs, field.Invalid(fieldPath, v, msg)) } } - return allErrs } func validateURL(name string, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - u, err := url.Parse(name) if err != nil { - return append(allErrs, field.Invalid(fieldPath, name, err.Error())) + return field.ErrorList{field.Invalid(fieldPath, name, err.Error())} } - var msg string if u.Scheme == "" { - msg = "scheme required, please use the prefix http(s)://" - return append(allErrs, field.Invalid(fieldPath, name, msg)) + return field.ErrorList{field.Invalid(fieldPath, name, "scheme required, please use the prefix http(s)://")} } if u.Host == "" { - msg = "hostname required" - return append(allErrs, field.Invalid(fieldPath, name, msg)) + return field.ErrorList{field.Invalid(fieldPath, name, "hostname required")} } if u.Path == "" { - msg = "path required" - return append(allErrs, field.Invalid(fieldPath, name, msg)) + return field.ErrorList{field.Invalid(fieldPath, name, "path required")} } host, port, err := net.SplitHostPort(u.Host) @@ -395,33 +376,31 @@ func validateURL(name string, fieldPath *field.Path) field.ErrorList { host = u.Host } - allErrs = append(allErrs, validateSSLName(host, fieldPath)...) + allErrs := validateSSLName(host, fieldPath) if port != "" { allErrs = append(allErrs, validatePortNumber(port, fieldPath)...) } - return allErrs } func validateQueryString(queryString string, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - _, err := url.ParseQuery(queryString) if err != nil { - return append(allErrs, field.Invalid(fieldPath, queryString, err.Error())) + return field.ErrorList{field.Invalid(fieldPath, queryString, err.Error())} } - - return allErrs + return nil } func validatePortNumber(port string, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - portInt, _ := strconv.Atoi(port) + portInt, err := strconv.Atoi(port) + if err != nil { + return field.ErrorList{field.Invalid(fieldPath, port, "invalid port")} + } msg := validation.IsValidPortNum(portInt) if msg != nil { - allErrs = append(allErrs, field.Invalid(fieldPath, port, msg[0])) + return field.ErrorList{field.Invalid(fieldPath, port, msg[0])} } - return allErrs + return nil } func validateSSLName(name string, fieldPath *field.Path) field.ErrorList { @@ -443,11 +422,10 @@ var validateVerifyClientKeyParameters = map[string]bool{ } func validateIngressMTLSVerifyClient(verifyClient string, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} if verifyClient != "" { - allErrs = append(allErrs, ValidateParameter(verifyClient, validateVerifyClientKeyParameters, fieldPath)...) + return ValidateParameter(verifyClient, validateVerifyClientKeyParameters, fieldPath) } - return allErrs + return nil } const ( @@ -458,31 +436,24 @@ const ( var rateRegexp = regexp.MustCompile("^" + rateFmt + "$") func validateRate(rate string, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if rate == "" { - return append(allErrs, field.Required(fieldPath, "")) + return field.ErrorList{field.Required(fieldPath, "")} } - if !rateRegexp.MatchString(rate) { msg := validation.RegexError(rateErrMsg, rateFmt, "16r/s", "32r/m", "64r/s") - return append(allErrs, field.Invalid(fieldPath, rate, msg)) + return field.ErrorList{field.Invalid(fieldPath, rate, msg)} } - return allErrs + return nil } func validateRateLimitZoneSize(zoneSize string, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if zoneSize == "" { - return append(allErrs, field.Required(fieldPath, "")) + return field.ErrorList{field.Required(fieldPath, "")} } - allErrs = append(allErrs, validateSize(zoneSize, fieldPath)...) - + allErrs := validateSize(zoneSize, fieldPath) kbZoneSize := strings.TrimSuffix(strings.ToLower(zoneSize), "k") kbZoneSizeNum, err := strconv.Atoi(kbZoneSize) - mbZoneSize := strings.TrimSuffix(strings.ToLower(zoneSize), "m") mbZoneSizeNum, mbErr := strconv.Atoi(mbZoneSize) @@ -504,34 +475,30 @@ var rateLimitKeyVariables = map[string]bool{ } func validateRateLimitKey(key string, fieldPath *field.Path, isPlus bool) field.ErrorList { - allErrs := field.ErrorList{} - if key == "" { - return append(allErrs, field.Required(fieldPath, "")) + return field.ErrorList{field.Required(fieldPath, "")} } + allErrs := field.ErrorList{} if err := ValidateEscapedString(key, `Hello World! \n`, `\"${request_uri}\" is unavailable. \n`); err != nil { allErrs = append(allErrs, field.Invalid(fieldPath, key, err.Error())) } - allErrs = append(allErrs, validateStringWithVariables(key, fieldPath, rateLimitKeySpecialVariables, rateLimitKeyVariables, isPlus)...) - return allErrs } var jwtTokenSpecialVariables = []string{"arg_", "http_", "cookie_"} func validateJWTToken(token string, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if token == "" { - return allErrs + return nil } nginxVars := strings.Split(token, "$") if len(nginxVars) != 2 { - return append(allErrs, field.Invalid(fieldPath, token, "must have 1 var")) + return field.ErrorList{field.Invalid(fieldPath, token, "must have 1 var")} } + nVar := token[1:] special := false @@ -541,16 +508,12 @@ func validateJWTToken(token string, fieldPath *field.Path) field.ErrorList { break } } - if special { // validateJWTToken is called only when NGINX Plus is running - isPlus := true - allErrs = append(allErrs, validateSpecialVariable(nVar, fieldPath, isPlus)...) + return validateSpecialVariable(nVar, fieldPath, true) } else { - return append(allErrs, field.Invalid(fieldPath, token, "must only have special vars")) + return field.ErrorList{field.Invalid(fieldPath, token, "must only have special vars")} } - - return allErrs } var validLogLevels = map[string]bool{ @@ -561,14 +524,11 @@ var validLogLevels = map[string]bool{ } func validateRateLimitLogLevel(logLevel string, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if !validLogLevels[logLevel] { - allErrs = append(allErrs, field.Invalid(fieldPath, logLevel, fmt.Sprintf("Accepted values: %s", - mapToPrettyString(validLogLevels)))) + return field.ErrorList{field.Invalid(fieldPath, logLevel, fmt.Sprintf("Accepted values: %s", + mapToPrettyString(validLogLevels)))} } - - return allErrs + return nil } const ( @@ -579,40 +539,30 @@ const ( var realmFmtRegexp = regexp.MustCompile("^" + realmFmt + "$") func validateRealm(realm string, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if !realmFmtRegexp.MatchString(realm) { msg := validation.RegexError(realmFmtErrMsg, realmFmt, "MyAPI", "My Product API") - allErrs = append(allErrs, field.Invalid(fieldPath, realm, msg)) + return field.ErrorList{field.Invalid(fieldPath, realm, msg)} } - - return allErrs + return nil } func validateIPorCIDR(ipOrCIDR string, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - _, _, err := net.ParseCIDR(ipOrCIDR) if err == nil { // valid CIDR - return allErrs + return nil } - ip := net.ParseIP(ipOrCIDR) if ip != nil { // valid IP - return allErrs + return nil } - - return append(allErrs, field.Invalid(fieldPath, ipOrCIDR, "must be a CIDR or IP")) + return field.ErrorList{field.Invalid(fieldPath, ipOrCIDR, "must be a CIDR or IP")} } func validatePositiveInt(n int, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if n <= 0 { - return append(allErrs, field.Invalid(fieldPath, n, "must be positive")) + return field.ErrorList{field.Invalid(fieldPath, n, "must be positive")} } - - return allErrs + return nil } diff --git a/pkg/apis/configuration/validation/policy_test.go b/pkg/apis/configuration/validation/policy_test.go index 48bbf441d6..c0babb99f2 100644 --- a/pkg/apis/configuration/validation/policy_test.go +++ b/pkg/apis/configuration/validation/policy_test.go @@ -309,6 +309,18 @@ func TestValidateAccessControlFails(t *testing.T) { } } +func TestValidateRate_ErrorsOnBogusRate(t *testing.T) { + t.Parallel() + + invalidRates := []string{"", "bogus"} + for _, v := range invalidRates { + allErrs := validateRate(v, field.NewPath("rate")) + if len(allErrs) == 0 { + t.Errorf("want err on invalid rate: %q, got nil", v) + } + } +} + func TestValidateRateLimit(t *testing.T) { t.Parallel() dryRun := true @@ -564,7 +576,7 @@ func TestValidateJWTFails(t *testing.T) { } } -func TestValidateIPorCIDR(t *testing.T) { +func TestValidateIPorCIDR_PassesOnValidInout(t *testing.T) { t.Parallel() validInput := []string{ "192.168.1.1", @@ -595,8 +607,27 @@ func TestValidateIPorCIDR(t *testing.T) { } } -func TestValidateRate(t *testing.T) { +func TestValidateIPorCIDR_FailsOnInvalidInput(t *testing.T) { t.Parallel() + + invalidInput := []string{ + "localhost", + "192.168.1.0/", + "2001:0db8:::1", + "2001:0db8::/", + } + + for _, input := range invalidInput { + allErrs := validateIPorCIDR(input, field.NewPath("ipOrCIDR")) + if len(allErrs) == 0 { + t.Errorf("validateIPorCIDR(%q) returned no errors for invalid input", input) + } + } +} + +func TestValidateRate_PassesOnValidInput(t *testing.T) { + t.Parallel() + validInput := []string{ "10r/s", "100r/m", @@ -609,6 +640,10 @@ func TestValidateRate(t *testing.T) { t.Errorf("validateRate(%q) returned errors %v for valid input", input, allErrs) } } +} + +func TestValidateRate_ErrorsOnInvalidInput(t *testing.T) { + t.Parallel() invalidInput := []string{ "10s", @@ -625,7 +660,7 @@ func TestValidateRate(t *testing.T) { } } -func TestValidatePositiveInt(t *testing.T) { +func TestValidatePositiveInt_PassesOnValidInput(t *testing.T) { t.Parallel() validInput := []int{1, 2} @@ -635,6 +670,10 @@ func TestValidatePositiveInt(t *testing.T) { t.Errorf("validatePositiveInt(%q) returned errors %v for valid input", input, allErrs) } } +} + +func TestValidatePositiveInt_ErrorsOnInvalidInput(t *testing.T) { + t.Parallel() invalidInput := []int{-1, 0} @@ -1132,6 +1171,18 @@ func TestValidateOIDCInvalid(t *testing.T) { } } +func TestValidatePortNumber_ErrorsOnInvalidPort(t *testing.T) { + t.Parallel() + + invalidPorts := []string{"bogus", ""} + for _, p := range invalidPorts { + allErrs := validatePortNumber(p, field.NewPath("port")) + if len(allErrs) == 0 { + t.Errorf("want err on invalid input %q, got nil", p) + } + } +} + func TestValidateClientID(t *testing.T) { t.Parallel() validInput := []string{"myid", "your.id", "id-sf-sjfdj.com", "foo_bar~vni"} @@ -1174,9 +1225,14 @@ func TestValidateOIDCScope(t *testing.T) { } } -func TestValidateURL(t *testing.T) { +func TestValidateURL_PassesOnValidInput(t *testing.T) { t.Parallel() - validInput := []string{"http://google.com/auth", "https://foo.bar/baz", "http://127.0.0.1/bar", "http://openid.connect.com:8080/foo"} + validInput := []string{ + "http://google.com/auth", + "https://foo.bar/baz", + "http://127.0.0.1/bar", + "http://openid.connect.com:8080/foo", + } for _, test := range validInput { allErrs := validateURL(test, field.NewPath("authEndpoint")) @@ -1184,8 +1240,21 @@ func TestValidateURL(t *testing.T) { t.Errorf("validateURL(%q) returned errors %v for valid input", allErrs, test) } } +} + +func TestValidateURL_ErrorsOnInvalidInput(t *testing.T) { + t.Parallel() - invalidInput := []string{"www.google..foo.com", "http://{foo.bar", `https://google.foo\bar`, "http://foo.bar:8080", "http://foo.bar:812345/fooo"} + invalidInput := []string{ + "www.google..foo.com", + "http://{foo.bar", + `https://google.foo\bar`, + "http://foo.bar:8080", + "http://foo.bar:812345/fooo", + "http://:812345/fooo", + "", + "bogusInput", + } for _, test := range invalidInput { allErrs := validateURL(test, field.NewPath("authEndpoint")) diff --git a/pkg/apis/configuration/validation/transportserver.go b/pkg/apis/configuration/validation/transportserver.go index 1eefd3697f..9111e4da73 100644 --- a/pkg/apis/configuration/validation/transportserver.go +++ b/pkg/apis/configuration/validation/transportserver.go @@ -35,9 +35,7 @@ func (tsv *TransportServerValidator) ValidateTransportServer(transportServer *v1 } func (tsv *TransportServerValidator) validateTransportServerSpec(spec *v1alpha1.TransportServerSpec, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - - allErrs = append(allErrs, tsv.validateTransportListener(&spec.Listener, fieldPath.Child("listener"))...) + allErrs := tsv.validateTransportListener(&spec.Listener, fieldPath.Child("listener")) isTLSPassthroughListener := isPotentialTLSPassthroughListener(&spec.Listener) allErrs = append(allErrs, validateTransportServerHost(spec.Host, fieldPath.Child("host"), isTLSPassthroughListener)...) @@ -65,42 +63,32 @@ func (tsv *TransportServerValidator) validateTransportServerSpec(spec *v1alpha1. } func validateTLS(tls *v1alpha1.TLS, isTLSPassthrough bool, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if tls == nil { - return allErrs + return nil } - if isTLSPassthrough { - return append(allErrs, field.Forbidden(fieldPath, "cannot specify secret for tls passthrough")) + return field.ErrorList{field.Forbidden(fieldPath, "cannot specify secret for tls passthrough")} } - if tls.Secret == "" { - return append(allErrs, field.Required(fieldPath, "must specify secret for tls")) + return field.ErrorList{field.Required(fieldPath, "must specify secret for tls")} } - - return append(allErrs, validateSecretName(tls.Secret, fieldPath.Child("secret"))...) + return validateSecretName(tls.Secret, fieldPath.Child("secret")) } func validateSnippets(serverSnippet string, fieldPath *field.Path, snippetsEnabled bool) field.ErrorList { - allErrs := field.ErrorList{} if !snippetsEnabled && serverSnippet != "" { - return append(allErrs, field.Forbidden(fieldPath, "snippet specified but snippets feature is not enabled")) + return field.ErrorList{field.Forbidden(fieldPath, "snippet specified but snippets feature is not enabled")} } - - return allErrs + return nil } func validateTransportServerHost(host string, fieldPath *field.Path, isTLSPassthroughListener bool) field.ErrorList { - allErrs := field.ErrorList{} - if !isTLSPassthroughListener { if host != "" { - return append(allErrs, field.Forbidden(fieldPath, "host field is allowed only for TLS Passthrough TransportServers")) + return field.ErrorList{field.Forbidden(fieldPath, "host field is allowed only for TLS Passthrough TransportServers")} } - return allErrs + return nil } - return validateHost(host, fieldPath) } @@ -113,11 +101,8 @@ func (tsv *TransportServerValidator) validateTransportListener(listener *v1alpha } func validateRegularListener(listener *v1alpha1.TransportServerListener, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - - allErrs = append(allErrs, validateListenerName(listener.Name, fieldPath.Child("name"))...) + allErrs := validateListenerName(listener.Name, fieldPath.Child("name")) allErrs = append(allErrs, validateListenerProtocol(listener.Protocol, fieldPath.Child("protocol"))...) - return allErrs } @@ -126,23 +111,18 @@ func isPotentialTLSPassthroughListener(listener *v1alpha1.TransportServerListene } func (tsv *TransportServerValidator) validateTLSPassthroughListener(listener *v1alpha1.TransportServerListener, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if !tsv.tlsPassthrough { - return append(allErrs, field.Forbidden(fieldPath, "TLS Passthrough is not enabled")) + return field.ErrorList{field.Forbidden(fieldPath, "TLS Passthrough is not enabled")} } - if listener.Name == v1alpha1.TLSPassthroughListenerName && listener.Protocol != v1alpha1.TLSPassthroughListenerProtocol { msg := fmt.Sprintf("must be '%s' for the built-in %s listener", v1alpha1.TLSPassthroughListenerProtocol, v1alpha1.TLSPassthroughListenerName) - return append(allErrs, field.Invalid(fieldPath.Child("protocol"), listener.Protocol, msg)) + return field.ErrorList{field.Invalid(fieldPath.Child("protocol"), listener.Protocol, msg)} } - if listener.Protocol == v1alpha1.TLSPassthroughListenerProtocol && listener.Name != v1alpha1.TLSPassthroughListenerName { msg := fmt.Sprintf("must be '%s' for a listener with the protocol %s", v1alpha1.TLSPassthroughListenerName, v1alpha1.TLSPassthroughListenerProtocol) - return append(allErrs, field.Invalid(fieldPath.Child("name"), listener.Name, msg)) + return field.ErrorList{field.Invalid(fieldPath.Child("name"), listener.Name, msg)} } - - return allErrs + return nil } func validateListenerName(name string, fieldPath *field.Path) field.ErrorList { @@ -156,19 +136,15 @@ var listenerProtocols = map[string]bool{ } func validateListenerProtocol(protocol string, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if protocol == "" { msg := fmt.Sprintf("must specify protocol. Accepted values: %s", mapToPrettyString(listenerProtocols)) - return append(allErrs, field.Required(fieldPath, msg)) + return field.ErrorList{field.Required(fieldPath, msg)} } - if !listenerProtocols[protocol] { msg := fmt.Sprintf("invalid protocol. Accepted values: %s", mapToPrettyString(listenerProtocols)) - allErrs = append(allErrs, field.Invalid(fieldPath, protocol, msg)) + return field.ErrorList{field.Invalid(fieldPath, protocol, msg)} } - - return allErrs + return nil } func validateTransportServerUpstreams(upstreams []v1alpha1.Upstream, fieldPath *field.Path, isPlus bool) (allErrs field.ErrorList, upstreamNames sets.Set[string]) { @@ -205,13 +181,11 @@ func validateTransportServerUpstreams(upstreams []v1alpha1.Upstream, fieldPath * } func validateLoadBalancingMethod(method string, fieldPath *field.Path, isPlus bool) field.ErrorList { - allErrs := field.ErrorList{} if method == "" { - return allErrs + return nil } method = strings.TrimSpace(method) - if strings.HasPrefix(method, "hash") { return validateHashLoadBalancingMethod(method, fieldPath, isPlus) } @@ -220,12 +194,10 @@ func validateLoadBalancingMethod(method string, fieldPath *field.Path, isPlus bo if isPlus { validMethodValues = nginxPlusStreamLoadBalanceValidInput } - if _, exists := validMethodValues[method]; !exists { - return append(allErrs, field.Invalid(fieldPath, method, fmt.Sprintf("load balancing method is not valid: %v", method))) + return field.ErrorList{field.Invalid(fieldPath, method, fmt.Sprintf("load balancing method is not valid: %v", method))} } - - return allErrs + return nil } var nginxStreamLoadBalanceValidInput = map[string]bool{ @@ -256,35 +228,30 @@ var loadBalancingVariables = map[string]bool{ var hashMethodRegexp = regexp.MustCompile(`^hash (\S+)(?: consistent)?$`) func validateHashLoadBalancingMethod(method string, fieldPath *field.Path, isPlus bool) field.ErrorList { - allErrs := field.ErrorList{} matches := hashMethodRegexp.FindStringSubmatch(method) if len(matches) != 2 { msg := fmt.Sprintf("invalid value for load balancing method: %v", method) - return append(allErrs, field.Invalid(fieldPath, method, msg)) + return field.ErrorList{field.Invalid(fieldPath, method, msg)} } + allErrs := field.ErrorList{} hashKey := matches[1] if strings.Contains(hashKey, "$") { varErrs := validateStringWithVariables(hashKey, fieldPath, []string{}, loadBalancingVariables, isPlus) allErrs = append(allErrs, varErrs...) } - if err := ValidateEscapedString(method); err != nil { msg := fmt.Sprintf("invalid value for hash: %v", err) return append(allErrs, field.Invalid(fieldPath, method, msg)) } - return allErrs } func validateTSUpstreamHealthChecks(hc *v1alpha1.HealthCheck, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if hc == nil { - return allErrs + return nil } - - allErrs = append(allErrs, validateTime(hc.Timeout, fieldPath.Child("timeout"))...) + allErrs := validateTime(hc.Timeout, fieldPath.Child("timeout")) allErrs = append(allErrs, validateTime(hc.Interval, fieldPath.Child("interval"))...) allErrs = append(allErrs, validateTime(hc.Jitter, fieldPath.Child("jitter"))...) allErrs = append(allErrs, validatePositiveIntOrZero(hc.Fails, fieldPath.Child("fails"))...) @@ -295,30 +262,26 @@ func validateTSUpstreamHealthChecks(hc *v1alpha1.HealthCheck, fieldPath *field.P allErrs = append(allErrs, field.Invalid(fieldPath.Child("port"), hc.Port, msg)) } } - allErrs = append(allErrs, validateHealthCheckMatch(hc.Match, fieldPath.Child("match"))...) - return allErrs } func validateHealthCheckMatch(match *v1alpha1.Match, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} if match == nil { - return allErrs + return nil } - allErrs = append(allErrs, validateMatchExpect(match.Expect, fieldPath.Child("expect"))...) + + allErrs := validateMatchExpect(match.Expect, fieldPath.Child("expect")) allErrs = append(allErrs, validateMatchSend(match.Expect, fieldPath.Child("send"))...) return allErrs } func validateMatchExpect(expect string, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} if expect == "" { - return allErrs + return nil } - if err := ValidateEscapedString(expect); err != nil { - return append(allErrs, field.Invalid(fieldPath, expect, err.Error())) + return field.ErrorList{field.Invalid(fieldPath, expect, err.Error())} } if strings.HasPrefix(expect, "~") { @@ -331,33 +294,30 @@ func validateMatchExpect(expect string, fieldPath *field.Path) field.ErrorList { // compile also validates hex literals if _, err := regexp.Compile(expr); err != nil { - return append(allErrs, field.Invalid(fieldPath, expr, fmt.Sprintf("must be a valid regular expression: %v", err))) + return field.ErrorList{field.Invalid(fieldPath, expr, fmt.Sprintf("must be a valid regular expression: %v", err))} } } else { if err := validateHexString(expect); err != nil { - return append(allErrs, field.Invalid(fieldPath, expect, err.Error())) + return field.ErrorList{field.Invalid(fieldPath, expect, err.Error())} } } - return allErrs + return nil } func validateMatchSend(send string, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} if send == "" { - return allErrs + return nil } if err := ValidateEscapedString(send); err != nil { - return append(allErrs, field.Invalid(fieldPath, send, err.Error())) + return field.ErrorList{field.Invalid(fieldPath, send, err.Error())} } - err := validateHexString(send) - if err != nil { - return append(allErrs, field.Invalid(fieldPath, send, err.Error())) + if err := validateHexString(send); err != nil { + return field.ErrorList{field.Invalid(fieldPath, send, err.Error())} } - - return allErrs + return nil } var hexLiteralRegexp = regexp.MustCompile(`\\x(.{0,2})`) @@ -382,49 +342,35 @@ func validateHexString(s string) error { } func validateTransportServerUpstreamParameters(upstreamParameters *v1alpha1.UpstreamParameters, fieldPath *field.Path, protocol string) field.ErrorList { - allErrs := field.ErrorList{} - if upstreamParameters == nil { - return allErrs + return nil } - allErrs = append(allErrs, validateUDPUpstreamParameter(upstreamParameters.UDPRequests, fieldPath.Child("udpRequests"), protocol)...) + allErrs := validateUDPUpstreamParameter(upstreamParameters.UDPRequests, fieldPath.Child("udpRequests"), protocol) allErrs = append(allErrs, validateUDPUpstreamParameter(upstreamParameters.UDPResponses, fieldPath.Child("udpResponses"), protocol)...) allErrs = append(allErrs, validateTime(upstreamParameters.ConnectTimeout, fieldPath.Child("connectTimeout"))...) allErrs = append(allErrs, validateTime(upstreamParameters.NextUpstreamTimeout, fieldPath.Child("nextUpstreamTimeout"))...) allErrs = append(allErrs, validatePositiveIntOrZero(upstreamParameters.NextUpstreamTries, fieldPath.Child("nextUpstreamTries"))...) - return allErrs } func validateSessionParameters(sessionParameters *v1alpha1.SessionParameters, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if sessionParameters == nil { - return allErrs + return nil } - - allErrs = append(allErrs, validateTime(sessionParameters.Timeout, fieldPath.Child("timeout"))...) - - return allErrs + return validateTime(sessionParameters.Timeout, fieldPath.Child("timeout")) } func validateUDPUpstreamParameter(parameter *int, fieldPath *field.Path, protocol string) field.ErrorList { - allErrs := field.ErrorList{} - if parameter != nil && protocol != "UDP" { - return append(allErrs, field.Forbidden(fieldPath, "is not allowed for non-UDP TransportServers")) + return field.ErrorList{field.Forbidden(fieldPath, "is not allowed for non-UDP TransportServers")} } - return validatePositiveIntOrZeroFromPointer(parameter, fieldPath) } func validateTransportServerAction(action *v1alpha1.Action, fieldPath *field.Path, upstreamNames sets.Set[string]) field.ErrorList { - allErrs := field.ErrorList{} - if action.Pass == "" { - return append(allErrs, field.Required(fieldPath, "must specify pass")) + return field.ErrorList{field.Required(fieldPath, "must specify pass")} } - return validateReferencedUpstream(action.Pass, fieldPath.Child("pass"), upstreamNames) } diff --git a/pkg/apis/configuration/validation/virtualserver.go b/pkg/apis/configuration/validation/virtualserver.go index 1ac7be51e8..434c49223c 100644 --- a/pkg/apis/configuration/validation/virtualserver.go +++ b/pkg/apis/configuration/validation/virtualserver.go @@ -95,12 +95,11 @@ func (vsv *VirtualServerValidator) validateVirtualServerSpec(spec *v1.VirtualSer const wildcardPrefix = "*." func validateHost(host string, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if host == "" { - return append(allErrs, field.Required(fieldPath, "")) + return field.ErrorList{field.Required(fieldPath, "")} } + allErrs := field.ErrorList{} if strings.HasPrefix(host, wildcardPrefix) { for _, msg := range validation.IsWildcardDNS1123Subdomain(host) { allErrs = append(allErrs, field.Invalid(fieldPath, host, msg)) @@ -153,83 +152,67 @@ func validatePolicies(policies []v1.PolicyReference, fieldPath *field.Path, name } func (vsv *VirtualServerValidator) validateTLS(tls *v1.TLS, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if tls == nil { // valid case - tls is not defined - return allErrs + return nil } - allErrs = append(allErrs, validateSecretName(tls.Secret, fieldPath.Child("secret"))...) - + allErrs := validateSecretName(tls.Secret, fieldPath.Child("secret")) allErrs = append(allErrs, validateTLSRedirect(tls.Redirect, fieldPath.Child("redirect"))...) - allErrs = append(allErrs, validateTLSCmFields(tls.CertManager, vsv.isCertManagerEnabled, tls.Secret, fieldPath.Child("cert-manager"))...) - return allErrs } func validateTLSCmFields(cm *v1.CertManager, isCertManagerEnabled bool, secret string, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if cm == nil { // valid, cert-manager is not required - return allErrs + return nil } + allErrs := field.ErrorList{} if !isCertManagerEnabled { allErrs = append(allErrs, field.Forbidden(fieldPath, "field requires cert-manager enablement")) } - if secret == "" { // invalid, secret name is required for cert-manager configuration allErrs = append(allErrs, field.Forbidden(fieldPath, "field requires TLS.Secret to be specified")) } - return allErrs } func validateDos(isDosEnabled bool, dos string, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if dos == "" { // valid, dos is not required - return allErrs + return nil } + allErrs := field.ErrorList{} if !isDosEnabled { allErrs = append(allErrs, field.Forbidden(fieldPath, "field requires DOS enablement")) } - for _, msg := range validation.IsQualifiedName(dos) { allErrs = append(allErrs, field.Invalid(fieldPath, dos, msg)) } - return allErrs } func (vsv *VirtualServerValidator) validateExternalDNS(ed *v1.ExternalDNS, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if ed == nil || !ed.Enable { // valid, externalDNS is not required - return allErrs + return nil } - if !vsv.isExternalDNSEnabled { - allErrs = append(allErrs, field.Forbidden(fieldPath, "field requires externalDNS enablement")) + return field.ErrorList{field.Forbidden(fieldPath, "field requires externalDNS enablement")} } - - return allErrs + return nil } func validateTLSRedirect(redirect *v1.TLSRedirect, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if redirect == nil { - return allErrs + return nil } + allErrs := field.ErrorList{} if redirect.Code != nil { allErrs = append(allErrs, validateRedirectStatusCode(*redirect.Code, fieldPath.Child("code"))...) } @@ -237,7 +220,6 @@ func validateTLSRedirect(redirect *v1.TLSRedirect, fieldPath *field.Path) field. if redirect.BasedOn != "" && redirect.BasedOn != "scheme" && redirect.BasedOn != "x-forwarded-proto" { allErrs = append(allErrs, field.Invalid(fieldPath.Child("basedOn"), redirect.BasedOn, "accepted values are 'scheme', 'x-forwarded-proto'")) } - return allErrs } @@ -249,45 +231,35 @@ var validRedirectStatusCodes = map[int]bool{ } func validateRedirectStatusCode(code int, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if _, ok := validRedirectStatusCodes[code]; !ok { - allErrs = append(allErrs, field.Invalid(fieldPath, code, "status code out of accepted range. accepted values are '301', '302', '307', '308'")) + return field.ErrorList{field.Invalid(fieldPath, code, "status code out of accepted range. accepted values are '301', '302', '307', '308'")} } - - return allErrs + return nil } func validatePositiveIntOrZero(n int, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if n < 0 { - return append(allErrs, field.Invalid(fieldPath, n, "must be positive")) + return field.ErrorList{field.Invalid(fieldPath, n, "must be positive")} } - - return allErrs + return nil } func validatePositiveIntOrZeroFromPointer(n *int, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} if n == nil { - return allErrs + return nil } - if *n < 0 { - return append(allErrs, field.Invalid(fieldPath, n, "must be positive or zero")) + return field.ErrorList{field.Invalid(fieldPath, n, "must be positive or zero")} } - - return allErrs + return nil } func validateBuffer(buff *v1.UpstreamBuffers, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if buff == nil { - return allErrs + return nil } + allErrs := field.ErrorList{} if buff.Number <= 0 { allErrs = append(allErrs, field.Invalid(fieldPath.Child("number"), buff.Number, "must be positive")) } @@ -297,40 +269,34 @@ func validateBuffer(buff *v1.UpstreamBuffers, fieldPath *field.Path) field.Error } else { allErrs = append(allErrs, validateSize(buff.Size, fieldPath.Child("size"))...) } - return allErrs } func validateUpstreamLBMethod(lBMethod string, fieldPath *field.Path, isPlus bool) field.ErrorList { - allErrs := field.ErrorList{} if lBMethod == "" { - return allErrs + return nil } if isPlus { _, err := configs.ParseLBMethodForPlus(lBMethod) if err != nil { - return append(allErrs, field.Invalid(fieldPath, lBMethod, err.Error())) + return field.ErrorList{field.Invalid(fieldPath, lBMethod, err.Error())} } } else { _, err := configs.ParseLBMethod(lBMethod) if err != nil { - return append(allErrs, field.Invalid(fieldPath, lBMethod, err.Error())) + return field.ErrorList{field.Invalid(fieldPath, lBMethod, err.Error())} } } - - return allErrs + return nil } func validateUpstreamHealthCheck(hc *v1.HealthCheck, typeName string, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if hc == nil { - return allErrs + return nil } - allErrs = append(allErrs, validateGrpcHealthCheck(hc, typeName, fieldPath)...) - + allErrs := validateGrpcHealthCheck(hc, typeName, fieldPath) if hc.Path != "" { allErrs = append(allErrs, validatePath(hc.Path, fieldPath.Child("path"))...) } @@ -403,12 +369,11 @@ func validateGrpcStatus(i *int, fieldPath *field.Path) field.ErrorList { } func validateSessionCookie(sc *v1.SessionCookie, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if sc == nil { - return allErrs + return nil } + allErrs := field.ErrorList{} if sc.Name == "" { allErrs = append(allErrs, field.Required(fieldPath.Child("name"), "")) } else { @@ -439,26 +404,24 @@ func validateSessionCookie(sc *v1.SessionCookie, fieldPath *field.Path) field.Er // validateUpstreamType validates that the protocol type of the upstream is of a supported protocol. // Current supported protocols are "http" and "grpc". If unset, it will default to "http". func validateUpstreamType(typeName string, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if typeName == "" { - return allErrs + return nil } - if typeName != "grpc" && typeName != "http" { - allErrs = append(allErrs, field.Invalid(fieldPath, typeName, "must be one of `grpc` or `http`")) + switch typeName { + case "grpc", "http": + return nil + default: + return field.ErrorList{field.Invalid(fieldPath, typeName, "must be one of `grpc` or `http`")} } - - return allErrs } func validateStatusMatch(s string, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if s == "" { - return allErrs + return nil } + allErrs := field.ErrorList{} if strings.HasPrefix(s, "!") { if !strings.HasPrefix(s, "! ") { allErrs = append(allErrs, field.Invalid(fieldPath, s, "must have an space character after the `!`")) @@ -479,7 +442,6 @@ func validateStatusMatch(s string, fieldPath *field.Path) field.ErrorList { allErrs = append(allErrs, field.Invalid(fieldPath, s, msg)) } } - return allErrs } @@ -644,11 +606,13 @@ var validNextUpstreamParams = map[string]bool{ // validateNextUpstream checks the values given for passing queries to a upstream func validateNextUpstream(nextUpstream string, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - allParams := sets.Set[string]{} if nextUpstream == "" { - return allErrs + return nil } + + allErrs := field.ErrorList{} + allParams := sets.Set[string]{} + params := strings.Fields(nextUpstream) for _, para := range params { if !validNextUpstreamParams[para] { @@ -678,16 +642,14 @@ func validateServiceName(name string, fieldPath *field.Path) field.ErrorList { } func validateDNS1035Label(name string, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if name == "" { - return append(allErrs, field.Required(fieldPath, "")) + return field.ErrorList{field.Required(fieldPath, "")} } + allErrs := field.ErrorList{} for _, msg := range validation.IsDNS1035Label(name) { allErrs = append(allErrs, field.Invalid(fieldPath, name, msg)) } - return allErrs } @@ -714,9 +676,7 @@ func (vsv *VirtualServerValidator) validateVirtualServerRoutes(routes []v1.Route } func (vsv *VirtualServerValidator) validateRoute(route v1.Route, fieldPath *field.Path, upstreamNames sets.Set[string], isRouteFieldForbidden bool, namespace string) field.ErrorList { - allErrs := field.ErrorList{} - - allErrs = append(allErrs, validateRoutePath(route.Path, fieldPath.Child("path"))...) + allErrs := validateRoutePath(route.Path, fieldPath.Child("path")) allErrs = append(allErrs, validatePolicies(route.Policies, fieldPath.Child("policies"), namespace)...) fieldCount := 0 @@ -780,16 +740,14 @@ func errorPageHasRequiredFields(errorPage v1.ErrorPage) bool { } func (vsv *VirtualServerValidator) validateErrorPage(errorPage v1.ErrorPage, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if !errorPageHasRequiredFields(errorPage) { - return append(allErrs, field.Required(fieldPath, "must specify exactly one of `redirect` or `return`")) + return field.ErrorList{field.Required(fieldPath, "must specify exactly one of `redirect` or `return`")} } - if len(errorPage.Codes) == 0 { - return append(allErrs, field.Required(fieldPath.Child("codes"), "must include at least 1 status code in `codes`")) + return field.ErrorList{field.Required(fieldPath.Child("codes"), "must include at least 1 status code in `codes`")} } + allErrs := field.ErrorList{} for i, c := range errorPage.Codes { for _, msg := range validation.IsInRange(c, 300, 599) { allErrs = append(allErrs, field.Invalid(fieldPath.Child("codes").Index(i), c, msg)) @@ -803,21 +761,17 @@ func (vsv *VirtualServerValidator) validateErrorPage(errorPage v1.ErrorPage, fie if errorPage.Redirect != nil { allErrs = append(allErrs, vsv.validateErrorPageRedirect(errorPage.Redirect, fieldPath.Child("redirect"))...) } - return allErrs } var errorPageReturnBodyVariable = map[string]bool{"upstream_status": true} func (vsv *VirtualServerValidator) validateErrorPageReturn(r *v1.ErrorPageReturn, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - - allErrs = append(allErrs, vsv.validateActionReturn(&r.ActionReturn, fieldPath, nil, errorPageReturnBodyVariable)...) + allErrs := vsv.validateActionReturn(&r.ActionReturn, fieldPath, nil, errorPageReturnBodyVariable) for i, header := range r.Headers { allErrs = append(allErrs, vsv.validateErrorPageHeader(header, fieldPath.Child("headers").Index(i))...) } - return allErrs } @@ -846,11 +800,7 @@ func (vsv *VirtualServerValidator) validateErrorPageHeader(h v1.Header, fieldPat var validErrorPageRedirectVariables = map[string]bool{"scheme": true, "http_x_forwarded_proto": true} func (vsv *VirtualServerValidator) validateErrorPageRedirect(r *v1.ErrorPageRedirect, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - - allErrs = append(allErrs, vsv.validateActionRedirect(&r.ActionRedirect, fieldPath, validErrorPageRedirectVariables)...) - - return allErrs + return vsv.validateActionRedirect(&r.ActionRedirect, fieldPath, validErrorPageRedirectVariables) } func countActions(action *v1.Action) int { @@ -912,12 +862,11 @@ var validRedirectVariableNames = map[string]bool{ } func (vsv *VirtualServerValidator) validateAction(action *v1.Action, fieldPath *field.Path, upstreamNames sets.Set[string], path string, internal bool) field.ErrorList { - allErrs := field.ErrorList{} - if countActions(action) != 1 { - return append(allErrs, field.Required(fieldPath, "action must specify exactly one of `pass`, `redirect`, `return` or `proxy`")) + return field.ErrorList{field.Required(fieldPath, "action must specify exactly one of `pass`, `redirect`, `return` or `proxy`")} } + allErrs := field.ErrorList{} if action.Pass != "" { allErrs = append(allErrs, validateReferencedUpstream(action.Pass, fieldPath.Child("pass"), upstreamNames)...) } @@ -938,14 +887,11 @@ func (vsv *VirtualServerValidator) validateAction(action *v1.Action, fieldPath * } func (vsv *VirtualServerValidator) validateActionRedirect(redirect *v1.ActionRedirect, fieldPath *field.Path, validVars map[string]bool) field.ErrorList { - allErrs := field.ErrorList{} - - allErrs = append(allErrs, vsv.validateRedirectURL(redirect.URL, fieldPath.Child("url"), validVars)...) + allErrs := vsv.validateRedirectURL(redirect.URL, fieldPath.Child("url"), validVars) if redirect.Code != 0 { allErrs = append(allErrs, validateRedirectStatusCode(redirect.Code, fieldPath.Child("code"))...) } - return allErrs } @@ -964,53 +910,38 @@ func captureVariables(s string) []string { } func (vsv *VirtualServerValidator) validateRedirectURL(redirectURL string, fieldPath *field.Path, validVars map[string]bool) field.ErrorList { - allErrs := field.ErrorList{} - if redirectURL == "" { - return append(allErrs, field.Required(fieldPath, "must specify a url")) + return field.ErrorList{field.Required(fieldPath, "must specify a url")} } - if !strings.Contains(redirectURL, "://") { - return append(allErrs, field.Invalid(fieldPath, redirectURL, "must contain the protocol with '://', for example http://, https:// or ${scheme}://")) + return field.ErrorList{field.Invalid(fieldPath, redirectURL, "must contain the protocol with '://', for example http://, https:// or ${scheme}://")} } - if err := ValidateEscapedString(redirectURL, "http://www.nginx.com", "${scheme}://${host}/green/", `\"http://www.nginx.com\"`); err != nil { - return append(allErrs, field.Invalid(fieldPath, redirectURL, err.Error())) + return field.ErrorList{field.Invalid(fieldPath, redirectURL, err.Error())} } - - allErrs = append(allErrs, validateStringWithVariables(redirectURL, fieldPath, nil, validVars, vsv.isPlus)...) - - return allErrs + return validateStringWithVariables(redirectURL, fieldPath, nil, validVars, vsv.isPlus) } func validateActionReturnCode(code int, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if (code >= 200 && code <= 299) || (code >= 400 && code <= 599) { - return allErrs + return nil } - msg := "must be a valid status code either 2XX, 4XX or 5XX, for example, 200 or 402." - return append(allErrs, field.Invalid(fieldPath, code, msg)) + return field.ErrorList{field.Invalid(fieldPath, code, msg)} } func (vsv *VirtualServerValidator) validateActionReturn(r *v1.ActionReturn, fieldPath *field.Path, specialValidVars []string, validVars map[string]bool) field.ErrorList { - allErrs := field.ErrorList{} - if r.Body == "" { - return append(allErrs, field.Required(fieldPath.Child("body"), "")) + return field.ErrorList{field.Required(fieldPath.Child("body"), "")} } - allErrs = append(allErrs, validateEscapedStringWithVariables(r.Body, fieldPath.Child("body"), specialValidVars, validVars, vsv.isPlus)...) - + allErrs := validateEscapedStringWithVariables(r.Body, fieldPath.Child("body"), specialValidVars, validVars, vsv.isPlus) if r.Type != "" { allErrs = append(allErrs, validateActionReturnType(r.Type, fieldPath.Child("type"))...) } - if r.Code != 0 { allErrs = append(allErrs, validateActionReturnCode(r.Code, fieldPath.Child("code"))...) } - return allErrs } @@ -1034,14 +965,11 @@ var ( var actionReturnTypeRegexp = regexp.MustCompile("^" + actionReturnTypeFmt + "$") func validateActionReturnType(returnType string, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if !actionReturnTypeRegexp.MatchString(returnType) { msg := validation.RegexError(actionReturnTypeErr, actionReturnTypeFmt, "type/subtype", "application/json") - allErrs = append(allErrs, field.Invalid(fieldPath, returnType, msg)) + return field.ErrorList{field.Invalid(fieldPath, returnType, msg)} } - - return allErrs + return nil } func validateRouteField(value string, fieldPath *field.Path) field.ErrorList { @@ -1068,9 +996,7 @@ func validateReferencedUpstream(name string, fieldPath *field.Path, upstreamName } func (vsv *VirtualServerValidator) validateActionProxy(p *v1.ActionProxy, fieldPath *field.Path, upstreamNames sets.Set[string], path string, internal bool) field.ErrorList { - allErrs := field.ErrorList{} - - allErrs = append(allErrs, validateReferencedUpstream(p.Upstream, fieldPath.Child("upstream"), upstreamNames)...) + allErrs := validateReferencedUpstream(p.Upstream, fieldPath.Child("upstream"), upstreamNames) allErrs = append(allErrs, vsv.validateActionProxyRequestHeaders(p.RequestHeaders, fieldPath.Child("requestHeaders"))...) allErrs = append(allErrs, vsv.validateActionProxyResponseHeaders(p.ResponseHeaders, fieldPath.Child("responseHeaders"))...) @@ -1084,45 +1010,34 @@ func (vsv *VirtualServerValidator) validateActionProxy(p *v1.ActionProxy, fieldP } func validateStringNoVariables(s string, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - for i, char := range s { charLen := len(string(char)) if string(char) == "$" && i+charLen < len(s) { if _, err := strconv.Atoi(string(s[i+charLen])); err != nil { - return append(allErrs, field.Invalid(fieldPath, s, "`$` character can be only followed by a number")) + return field.ErrorList{field.Invalid(fieldPath, s, "`$` character can be only followed by a number")} } } } - - return allErrs + return nil } func validateActionProxyRewritePath(rewritePath string, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if rewritePath == "" { - return allErrs + return nil } - - allErrs = append(allErrs, validateStringNoVariables(rewritePath, fieldPath)...) - + allErrs := validateStringNoVariables(rewritePath, fieldPath) return append(allErrs, validatePath(rewritePath, fieldPath)...) } func validateActionProxyRewritePathForRegexp(rewritePath string, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if rewritePath == "" { - return allErrs + return nil } - allErrs = append(allErrs, validateStringNoVariables(rewritePath, fieldPath)...) - + allErrs := validateStringNoVariables(rewritePath, fieldPath) if err := ValidateEscapedString(rewritePath, "/rewrite$1", "/images"); err != nil { allErrs = append(allErrs, field.Invalid(fieldPath, rewritePath, err.Error())) } - return allErrs } @@ -1193,26 +1108,23 @@ func (vsv *VirtualServerValidator) validateActionProxyHeader(h v1.Header, fieldP } func (vsv *VirtualServerValidator) validateActionProxyRequestHeaders(requestHeaders *v1.ProxyRequestHeaders, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if requestHeaders == nil { - return allErrs + return nil } + allErrs := field.ErrorList{} for i, header := range requestHeaders.Set { allErrs = append(allErrs, vsv.validateActionProxyHeader(header, fieldPath.Index(i))...) } - return allErrs } func (vsv *VirtualServerValidator) validateActionProxyResponseHeaders(responseHeaders *v1.ProxyResponseHeaders, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if responseHeaders == nil { - return allErrs + return nil } + allErrs := field.ErrorList{} for i, header := range responseHeaders.Hide { for _, msg := range validation.IsHTTPHeaderName(header) { allErrs = append(allErrs, field.Invalid(fieldPath.Child("hide").Index(i), header, msg)) @@ -1247,30 +1159,27 @@ var validIgnoreHeaders = map[string]bool{ } func validateIgnoreHeaders(ignoreHeaders []string, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} if len(ignoreHeaders) == 0 { - return allErrs + return nil } + allErrs := field.ErrorList{} for i, h := range ignoreHeaders { if !validIgnoreHeaders[h] { msg := fmt.Sprintf("not a valid ignore header name. Accepted headers are : %v", mapToPrettyString(validIgnoreHeaders)) allErrs = append(allErrs, field.Invalid(fieldPath.Index(i), h, msg)) } } - return allErrs } func (vsv *VirtualServerValidator) validateSplits(splits []v1.Split, fieldPath *field.Path, upstreamNames sets.Set[string], path string) field.ErrorList { - allErrs := field.ErrorList{} - if len(splits) < 2 { - return append(allErrs, field.Invalid(fieldPath, "", "must include at least 2 splits")) + return field.ErrorList{field.Invalid(fieldPath, "", "must include at least 2 splits")} } + allErrs := field.ErrorList{} totalWeight := 0 - for i, s := range splits { idxPath := fieldPath.Index(i) @@ -1297,12 +1206,11 @@ func (vsv *VirtualServerValidator) validateSplits(splits []v1.Split, fieldPath * // We support prefix-based NGINX locations, positive case-sensitive/insensitive regular expressions matches and exact matches. // More info http://nginx.org/en/docs/http/ngx_http_core_module.html#location func validateRoutePath(path string, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if path == "" { - return append(allErrs, field.Required(fieldPath, "")) + return field.ErrorList{field.Required(fieldPath, "")} } + allErrs := field.ErrorList{} if strings.HasPrefix(path, "~") { allErrs = append(allErrs, validateRegexPath(path, fieldPath)...) } else if strings.HasPrefix(path, "/") { @@ -1312,22 +1220,17 @@ func validateRoutePath(path string, fieldPath *field.Path) field.ErrorList { } else { allErrs = append(allErrs, field.Invalid(fieldPath, path, "must start with /, ~ or =")) } - return allErrs } func validateRegexPath(path string, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if _, err := regexp.Compile(path); err != nil { - return append(allErrs, field.Invalid(fieldPath, path, fmt.Sprintf("must be a valid regular expression: %v", err))) + return field.ErrorList{field.Invalid(fieldPath, path, fmt.Sprintf("must be a valid regular expression: %v", err))} } - if err := ValidateEscapedString(path, "*.jpg", "^/images/image_*.png$"); err != nil { - return append(allErrs, field.Invalid(fieldPath, path, err.Error())) + return field.ErrorList{field.Invalid(fieldPath, path, err.Error())} } - - return allErrs + return nil } const ( @@ -1338,18 +1241,14 @@ const ( var pathRegexp = regexp.MustCompile("^" + pathFmt + "$") func validatePath(path string, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if path == "" { - return append(allErrs, field.Required(fieldPath, "")) + return field.ErrorList{field.Required(fieldPath, "")} } - if !pathRegexp.MatchString(path) { msg := validation.RegexError(pathErrMsg, pathFmt, "/", "/path", "/path/subpath-123") - return append(allErrs, field.Invalid(fieldPath, path, msg)) + return field.ErrorList{field.Invalid(fieldPath, path, msg)} } - - return allErrs + return nil } const ( @@ -1360,18 +1259,14 @@ const ( var grpcRegexp = regexp.MustCompile("^" + grpcFmt + "$") func validateGrpcService(service string, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if service == "" { - return allErrs + return nil } - if !grpcRegexp.MatchString(service) { msg := validation.RegexError(grpcErrMsg, grpcFmt, "GrpcService", "GrpcService.MyService") - return append(allErrs, field.Invalid(fieldPath, service, msg)) + return field.ErrorList{field.Invalid(fieldPath, service, msg)} } - - return allErrs + return nil } func (vsv *VirtualServerValidator) validateMatch(match v1.Match, fieldPath *field.Path, upstreamNames sets.Set[string], path string) field.ErrorList { @@ -1491,17 +1386,13 @@ var validVariableNames = map[string]bool{ } func validateVariableName(name string, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if !strings.HasPrefix(name, "$") { - return append(allErrs, field.Invalid(fieldPath, name, "must start with `$`")) + return field.ErrorList{field.Invalid(fieldPath, name, "must start with `$`")} } - if _, exists := validVariableNames[name]; !exists { - return append(allErrs, field.Invalid(fieldPath, name, "is not allowed or is not an NGINX variable")) + return field.ErrorList{field.Invalid(fieldPath, name, "is not allowed or is not an NGINX variable")} } - - return allErrs + return nil } func isValidMatchValue(value string) []string { @@ -1527,9 +1418,7 @@ func (vsv *VirtualServerValidator) ValidateVirtualServerRouteForVirtualServer(vi func (vsv *VirtualServerValidator) validateVirtualServerRouteSpec(spec *v1.VirtualServerRouteSpec, fieldPath *field.Path, virtualServerHost string, vsPath string, namespace string, ) field.ErrorList { - allErrs := field.ErrorList{} - - allErrs = append(allErrs, validateVirtualServerRouteHost(spec.Host, virtualServerHost, fieldPath.Child("host"))...) + allErrs := validateVirtualServerRouteHost(spec.Host, virtualServerHost, fieldPath.Child("host")) upstreamErrs, upstreamNames := vsv.validateUpstreams(spec.Upstreams, fieldPath.Child("upstreams")) allErrs = append(allErrs, upstreamErrs...) @@ -1540,15 +1429,11 @@ func (vsv *VirtualServerValidator) validateVirtualServerRouteSpec(spec *v1.Virtu } func validateVirtualServerRouteHost(host string, virtualServerHost string, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - - allErrs = append(allErrs, validateHost(host, fieldPath)...) - + allErrs := validateHost(host, fieldPath) if virtualServerHost != "" && host != virtualServerHost { msg := fmt.Sprintf("must be equal to '%s'", virtualServerHost) allErrs = append(allErrs, field.Invalid(fieldPath, host, msg)) } - return allErrs } @@ -1598,12 +1483,11 @@ func (vsv *VirtualServerValidator) validateVirtualServerRouteSubroutes(routes [] } func rejectPlusResourcesInOSS(upstream v1.Upstream, idxPath *field.Path, isPlus bool) field.ErrorList { - allErrs := field.ErrorList{} - if isPlus { - return allErrs + return nil } + allErrs := field.ErrorList{} if upstream.HealthCheck != nil { allErrs = append(allErrs, field.Forbidden(idxPath.Child("healthCheck"), "active health checks are only supported in NGINX Plus")) } @@ -1628,17 +1512,13 @@ func rejectPlusResourcesInOSS(upstream v1.Upstream, idxPath *field.Path, isPlus } func validateQueue(queue *v1.UpstreamQueue, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - if queue == nil { - return allErrs + return nil } - - allErrs = append(allErrs, validateTime(queue.Timeout, fieldPath.Child("timeout"))...) + allErrs := validateTime(queue.Timeout, fieldPath.Child("timeout")) if queue.Size <= 0 { allErrs = append(allErrs, field.Required(fieldPath.Child("size"), "must be positive")) } - return allErrs }