diff --git a/deployments/common/policy-definition.yaml b/deployments/common/policy-definition.yaml index fad2e5552a..85d63ce0ab 100644 --- a/deployments/common/policy-definition.yaml +++ b/deployments/common/policy-definition.yaml @@ -35,10 +35,9 @@ spec: metadata: type: object spec: - description: 'PolicySpec is the spec of the Policy resource. The spec includes - multiple fields, where each field represents a different policy. Note: - currently we have only one policy -- AccessControl, but we will support - more in the future. Only one policy (field) is allowed.' + description: PolicySpec is the spec of the Policy resource. The spec includes + multiple fields, where each field represents a different policy. Only + one policy (field) is allowed. type: object properties: accessControl: @@ -54,3 +53,25 @@ spec: type: array items: type: string + rateLimit: + description: RateLimit defines a rate limit policy. + type: object + properties: + burst: + type: integer + delay: + type: integer + dryRun: + type: boolean + key: + type: string + logLevel: + type: string + noDelay: + type: boolean + rate: + type: string + rejectCode: + type: integer + zoneSize: + type: string diff --git a/deployments/helm-chart/crds/policy.yaml b/deployments/helm-chart/crds/policy.yaml index 0e5b9b9e4a..f5fda77842 100644 --- a/deployments/helm-chart/crds/policy.yaml +++ b/deployments/helm-chart/crds/policy.yaml @@ -37,10 +37,9 @@ spec: metadata: type: object spec: - description: 'PolicySpec is the spec of the Policy resource. The spec includes - multiple fields, where each field represents a different policy. Note: - currently we have only one policy -- AccessControl, but we will support - more in the future. Only one policy (field) is allowed.' + description: PolicySpec is the spec of the Policy resource. The spec includes + multiple fields, where each field represents a different policy. Only + one policy (field) is allowed. type: object properties: accessControl: @@ -56,3 +55,25 @@ spec: type: array items: type: string + rateLimit: + description: RateLimit defines a rate limit policy. + type: object + properties: + burst: + type: integer + delay: + type: integer + dryRun: + type: boolean + key: + type: string + logLevel: + type: string + noDelay: + type: boolean + rate: + type: string + rejectCode: + type: integer + zoneSize: + type: string diff --git a/docs-web/configuration/policy-resource.md b/docs-web/configuration/policy-resource.md index 2fae96774e..30114e48d1 100644 --- a/docs-web/configuration/policy-resource.md +++ b/docs-web/configuration/policy-resource.md @@ -1,6 +1,6 @@ # Policy Resource -The Policy resource allows you to configure features like authentication, rate-limiting, and WAF, which you can add to your [VirtualServer and VirtualServerRoute resources](/nginx-ingress-controller/configuration/virtualserver-and-virtualserverroute-resources/). In the initial release, we are introducing support for access control based on the client IP address. +The Policy resource allows you to configure features like access control and rate-limiting, which you can add to your [VirtualServer and VirtualServerRoute resources](/nginx-ingress-controller/configuration/virtualserver-and-virtualserverroute-resources/). The resource is implemented as a [Custom Resource](https://kubernetes.io/docs/concepts/extend-kubernetes/api-extension/custom-resources/). @@ -16,6 +16,8 @@ This document is the reference documentation for the Policy resource. An example - [Policy Specification](#policy-specification) - [AccessControl](#accesscontrol) - [AccessControl Merging Behavior](#accesscontrol-merging-behavior) + - [RateLimit](#ratelimit) + - [RateLimit Merging Behavior](#ratelimit-merging-behavior) - [Using Policy](#using-policy) - [Validation](#validation) - [Structural Validation](#structural-validation) @@ -50,9 +52,15 @@ spec: * - ``accessControl`` - The access control policy based on the client IP address. - `accessControl <#accesscontrol>`_ - - Yes + - No* + * - ``rateLimit`` + - The rate limit policy controls the rate of processing requests per a defined key. + - `rateLimit <#ratelimit>`_ + - No* ``` +\* A policy must include exactly one policy. + ### AccessControl The access control policy configures NGINX to deny or allow requests from clients with the specified IP addresses/subnets. @@ -109,6 +117,78 @@ Referencing both allow and deny policies, as shown in the example below, is not - name: allow-policy-two ``` +### RateLimit + +The rate limit policy configures NGINX to limit the processing rate of requests. + +For example, the following policy will limit all subsequent requests coming from a single IP address once a rate of 10 requests per second is exceeded: +```yaml +rateLimit: + rate: 10r/s + zoneSize: 10M + key: ${binary_remote_addr} +``` + +> Note: The feature is implemented using the NGINX [ngx_http_limit_req_module](https://nginx.org/en/docs/http/ngx_http_limit_req_module.html). + +```eval_rst +.. list-table:: + :header-rows: 1 + + * - Field + - Description + - Type + - Required + * - ``rate`` + - The rate of requests permitted. The rate is specified in requests per second (r/s) or requests per minute (r/m). + - ``string`` + - Yes + * - ``key`` + - The key to which the rate limit is applied. Can contain text, variables, or a combination of them. Variables must be surrounded by ``${}``. For example: ``${binary_remote_addr}``. Accepted variables are ``$binary_remote_addr``, ``$request_uri``, ``$url``, ``$http_``, ``$args``, ``$arg_``, ``$cookie_``. + - ``string`` + - Yes + * - ``zoneSize`` + - Size of the shared memory zone. Only positive values are allowed. Allowed suffixes are ``k`` or ``m``, if none are present ``k`` is assumed. + - ``string`` + - Yes + * - ``delay`` + - The delay parameter specifies a limit at which excessive requests become delayed. If not set all excessive requests are delayed. + - ``int`` + - No* + * - ``noDelay`` + - Disables the delaying of excessive requests while requests are being limited. Overrides ``delay`` if both are set. + - ``bool`` + - No* + * - ``burst`` + - Excessive requests are delayed until their number exceeds the ``burst`` size, in which case the request is terminated with an error. + - ``int`` + - No* + * - ``dryRun`` + - Enables the dry run mode. In this mode, the rate limit is not actually applied, but the the number of excessive requests is accounted as usual in the shared memory zone. + - ``bool`` + - No* + * - ``logLevel`` + - Sets the desired logging level for cases when the server refuses to process requests due to rate exceeding, or delays request processing. Allowed values are ``info``, ``notice``, ``warn`` or ``error``. Default is ``error``. + - ``string`` + - No* + * - ``rejectCode`` + - Sets the status code to return in response to rejected requests. Must fall into the range ``400..599``. Default is ``503``. + - ``string`` + - No* +``` + +> For each policy referenced in a VirtualServer and/or its VirtualServerRoutes, the Ingress Controller will generate a single rate limiting zone defined by the [`limit_req_zone`](http://nginx.org/en/docs/http/ngx_http_limit_req_module.html#limit_req_zone) directive. If two VirtualServer resources reference the same policy, the Ingress Controller will generate two different rate limiting zones, one zone per VirtualServer. + +#### RateLimit Merging Behavior +A VirtualServer/VirtualServerRoute can reference multiple rate limit policies. For example, here we reference two policies: +```yaml +policies: +- name: rate-limit-policy-one +- name: rate-limit-policy-two +``` + +When you reference more than one rate limit policy, the Ingress Controller will configure NGINX to use all referenced rate limits. When you define multiple policies, each additional policy inherits the `dryRun`, `logLevel`, and `rejectCode` parameters from the first policy referenced (`rate-limit-policy-one`, in the example above). + ## Using Policy You can use the usual `kubectl` commands to work with Policy resources, just as with built-in Kubernetes resources. diff --git a/examples-of-custom-resources/rate-limit/README.md b/examples-of-custom-resources/rate-limit/README.md new file mode 100644 index 0000000000..eb4ce171cd --- /dev/null +++ b/examples-of-custom-resources/rate-limit/README.md @@ -0,0 +1,61 @@ +# Rate Limit + +In this example, we deploy a web application, configure load balancing for it via a VirtualServer, and apply a rate limit policy. + +## Prerequisites + +1. Follow the [installation](https://docs.nginx.com/nginx-ingress-controller/installation/installation-with-manifests/) instructions to deploy the Ingress Controller. +1. Save the public IP address of the Ingress Controller into a shell variable: + ``` + $ IC_IP=XXX.YYY.ZZZ.III + ``` +1. Save the HTTP port of the Ingress Controller into a shell variable: + ``` + $ IC_HTTP_PORT= + ``` + +## Step 1 - Deploy a Web Application + +Create the application deployment and service: +``` +$ kubectl apply -f webapp.yaml +``` + +## Step 2 - Deploy the Rate Limit Policy + +In this step, we create a policy with the name `rate-limit-policy` that allows only 1 request per second coming from a single IP address. + +Create the policy: +``` +$ kubectl apply -f rate-limit.yaml +``` + +## Step 3 - Configure Load Balancing + +Create a VirtualServer resource for the web application: +``` +$ kubectl apply -f virtual-server.yaml +``` + +Note that the VirtualServer references the policy `rate-limit-policy` created in Step 2. + +## Step 4 - Test the Configuration + +Let's test the configuration. If you access the application at a rate that exceeds one request per second, NGINX will start rejecting your requests: +``` +$ curl --resolve webapp.example.com:$IC_HTTP_PORT:$IC_IP http://webapp.example.com:$IC_HTTP_PORT/ +Server address: 10.8.1.19:8080 +Server name: webapp-dc88fc766-zr7f8 +. . . + +$ curl --resolve webapp.example.com:$IC_HTTP_PORT:$IC_IP http://webapp.example.com:$IC_HTTP_PORT/ + +503 Service Temporarily Unavailable + +

503 Service Temporarily Unavailable

+
nginx/1.19.1
+ + +``` + +> Note: The command result is truncated for the clarity of the example. diff --git a/examples-of-custom-resources/rate-limit/rate-limit.yaml b/examples-of-custom-resources/rate-limit/rate-limit.yaml new file mode 100644 index 0000000000..90f33a7337 --- /dev/null +++ b/examples-of-custom-resources/rate-limit/rate-limit.yaml @@ -0,0 +1,9 @@ +apiVersion: k8s.nginx.org/v1alpha1 +kind: Policy +metadata: + name: rate-limit-policy +spec: + rateLimit: + rate: 1r/s + key: ${binary_remote_addr} + zoneSize: 10M diff --git a/examples-of-custom-resources/rate-limit/virtual-server.yaml b/examples-of-custom-resources/rate-limit/virtual-server.yaml new file mode 100644 index 0000000000..59ecbda25e --- /dev/null +++ b/examples-of-custom-resources/rate-limit/virtual-server.yaml @@ -0,0 +1,16 @@ +apiVersion: k8s.nginx.org/v1 +kind: VirtualServer +metadata: + name: webapp +spec: + host: webapp.example.com + policies: + - name: rate-limit-policy + upstreams: + - name: webapp + service: webapp-svc + port: 80 + routes: + - path: / + action: + pass: webapp diff --git a/examples-of-custom-resources/rate-limit/webapp.yaml b/examples-of-custom-resources/rate-limit/webapp.yaml new file mode 100644 index 0000000000..67556cf616 --- /dev/null +++ b/examples-of-custom-resources/rate-limit/webapp.yaml @@ -0,0 +1,32 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: webapp +spec: + replicas: 1 + selector: + matchLabels: + app: webapp + template: + metadata: + labels: + app: webapp + spec: + containers: + - name: webapp + image: nginxdemos/nginx-hello:plain-text + ports: + - containerPort: 8080 +--- +apiVersion: v1 +kind: Service +metadata: + name: webapp-svc +spec: + ports: + - port: 80 + targetPort: 8080 + protocol: TCP + name: http + selector: + app: webapp \ No newline at end of file diff --git a/internal/configs/version2/http.go b/internal/configs/version2/http.go index 3736543439..9f396b9510 100644 --- a/internal/configs/version2/http.go +++ b/internal/configs/version2/http.go @@ -1,5 +1,7 @@ package version2 +import "fmt" + // UpstreamLabels describes the Prometheus labels for an NGINX upstream. type UpstreamLabels struct { Service string @@ -15,6 +17,7 @@ type VirtualServerConfig struct { SplitClients []SplitClient Maps []Map StatusMatches []StatusMatch + LimitReqZones []LimitReqZone HTTPSnippets []string SpiffeCerts bool } @@ -61,6 +64,8 @@ type Server struct { TLSPassthrough bool Allow []string Deny []string + LimitReqOptions LimitReqOptions + LimitReqs []LimitReq PoliciesErrorReturn *Return } @@ -105,6 +110,8 @@ type Location struct { Allow []string Deny []string PoliciesErrorReturn *Return + LimitReqOptions LimitReqOptions + LimitReqs []LimitReq } // ReturnLocation defines a location for returning a fixed response. @@ -224,3 +231,38 @@ type Queue struct { Size int Timeout string } + +// LimitReqZone defines a rate limit shared memory zone. +type LimitReqZone struct { + Key string + ZoneName string + ZoneSize string + Rate string +} + +func (rlz LimitReqZone) String() string { + return fmt.Sprintf("{Key %q, ZoneName %q, ZoneSize %v, Rate %q}", rlz.Key, rlz.ZoneName, rlz.ZoneSize, rlz.Rate) +} + +// LimitReq defines a rate limit. +type LimitReq struct { + ZoneName string + Burst int + NoDelay bool + Delay int +} + +func (rl LimitReq) String() string { + return fmt.Sprintf("{ZoneName %q, Burst %q, NoDelay %v, Delay %q}", rl.ZoneName, rl.Burst, rl.NoDelay, rl.Delay) +} + +// LimitReqOptions defines rate limit options. +type LimitReqOptions struct { + DryRun bool + LogLevel string + RejectCode int +} + +func (rl LimitReqOptions) String() string { + return fmt.Sprintf("{DryRun %v, LogLevel %q, RejectCode %q}", rl.DryRun, rl.LogLevel, rl.RejectCode) +} diff --git a/internal/configs/version2/nginx-plus.virtualserver.tmpl b/internal/configs/version2/nginx-plus.virtualserver.tmpl index 50bfa45f8b..359e6c7d46 100644 --- a/internal/configs/version2/nginx-plus.virtualserver.tmpl +++ b/internal/configs/version2/nginx-plus.virtualserver.tmpl @@ -44,6 +44,10 @@ map {{ $m.Source }} {{ $m.Variable }} { {{- $snippet }} {{ end }} +{{ range $z := .LimitReqZones }} +limit_req_zone {{ $z.Key }} zone={{ $z.ZoneName }}:{{ $z.ZoneSize }} rate={{ $z.Rate }}; +{{ end }} + {{ range $m := .StatusMatches }} match {{ $m.Name }} { status {{ $m.Code }}; @@ -110,6 +114,23 @@ server { allow all; {{ end }} + {{ if $s.LimitReqOptions.DryRun }} + limit_req_dry_run on; + {{ end }} + + {{ with $level := $s.LimitReqOptions.LogLevel }} + limit_req_log_level {{ $level }}; + {{ end }} + + {{ with $code := $s.LimitReqOptions.RejectCode }} + limit_req_status {{ $code }}; + {{ end }} + + {{ range $rl := $s.LimitReqs }} + limit_req zone={{ $rl.ZoneName }}{{ if $rl.Burst }} burst={{ $rl.Burst }}{{ end }} + {{ if $rl.Delay }} delay={{ $rl.Delay }}{{ end }}{{ if $rl.NoDelay }} nodelay{{ end }}; + {{ end }} + {{ range $snippet := $s.Snippets }} {{- $snippet }} {{ end }} @@ -182,6 +203,23 @@ server { allow all; {{ end }} + {{ if $l.LimitReqOptions.DryRun }} + limit_req_dry_run on; + {{ end }} + + {{ with $level := $l.LimitReqOptions.LogLevel }} + limit_req_log_level {{ $level }}; + {{ end }} + + {{ with $code := $l.LimitReqOptions.RejectCode }} + limit_req_status {{ $code }}; + {{ end }} + + {{ range $rl := $l.LimitReqs }} + limit_req zone={{ $rl.ZoneName }}{{ if $rl.Burst }} burst={{ $rl.Burst }}{{ end }} + {{ if $rl.Delay }} delay={{ $rl.Delay }}{{ end }}{{ if $rl.NoDelay }} nodelay{{ end }}; + {{ end }} + {{ range $e := $l.ErrorPages }} error_page {{ $e.Codes }} {{ if ne 0 $e.ResponseCode }}={{ $e.ResponseCode }}{{ end }} "{{ $e.Name }}"; {{ end }} diff --git a/internal/configs/version2/nginx.virtualserver.tmpl b/internal/configs/version2/nginx.virtualserver.tmpl index b246ccadcf..d90ac77341 100644 --- a/internal/configs/version2/nginx.virtualserver.tmpl +++ b/internal/configs/version2/nginx.virtualserver.tmpl @@ -34,6 +34,10 @@ map {{ $m.Source }} {{ $m.Variable }} { {{- $snippet }} {{ end }} +{{ range $z := .LimitReqZones }} +limit_req_zone {{ $z.Key }} zone={{ $z.ZoneName }}:{{ $z.ZoneSize }} rate={{ $z.Rate }}; +{{ end }} + {{ $s := .Server }} server { listen 80{{ if $s.ProxyProtocol }} proxy_protocol{{ end }}; @@ -93,6 +97,23 @@ server { allow all; {{ end }} + {{ if $s.LimitReqOptions.DryRun }} + limit_req_dry_run on; + {{ end }} + + {{ with $level := $s.LimitReqOptions.LogLevel }} + limit_req_log_level {{ $level }}; + {{ end }} + + {{ with $code := $s.LimitReqOptions.RejectCode }} + limit_req_status {{ $code }}; + {{ end }} + + {{ range $rl := $s.LimitReqs }} + limit_req zone={{ $rl.ZoneName }}{{ if $rl.Burst }} burst={{ $rl.Burst }}{{ end }} + {{ if $rl.Delay }} delay={{ $rl.Delay }}{{ end }}{{ if $rl.NoDelay }} nodelay{{ end }}; + {{ end }} + {{ range $snippet := $s.Snippets }} {{- $snippet }} {{ end }} @@ -151,6 +172,23 @@ server { allow all; {{ end }} + {{ if $l.LimitReqOptions.DryRun }} + limit_req_dry_run on; + {{ end }} + + {{ with $level := $l.LimitReqOptions.LogLevel }} + limit_req_log_level {{ $level }}; + {{ end }} + + {{ with $code := $l.LimitReqOptions.RejectCode }} + limit_req_status {{ $code }}; + {{ end }} + + {{ range $rl := $l.LimitReqs }} + limit_req zone={{ $rl.ZoneName }}{{ if $rl.Burst }} burst={{ $rl.Burst }}{{ end }} + {{ if $rl.Delay }} delay={{ $rl.Delay }}{{ end }}{{ if $rl.NoDelay }} nodelay{{ end }}; + {{ end }} + {{ range $e := $l.ErrorPages }} error_page {{ $e.Codes }} {{ if ne 0 $e.ResponseCode }}={{ $e.ResponseCode }}{{ end }} "{{ $e.Name }}"; {{ end }} diff --git a/internal/configs/version2/templates_test.go b/internal/configs/version2/templates_test.go index 735e15b43f..abd84fb977 100644 --- a/internal/configs/version2/templates_test.go +++ b/internal/configs/version2/templates_test.go @@ -10,6 +10,11 @@ const nginxPlusTransportServerTmpl = "nginx-plus.transportserver.tmpl" const nginxTransportServerTmpl = "nginx.transportserver.tmpl" var virtualServerCfg = VirtualServerConfig{ + LimitReqZones: []LimitReqZone{ + { + ZoneName: "pol_rl_test_test_test", Rate: "10r/s", ZoneSize: "10m", Key: "$url", + }, + }, Upstreams: []Upstream{ { Name: "test-upstream", @@ -120,7 +125,18 @@ var virtualServerCfg = VirtualServerConfig{ RealIPRecursive: true, Allow: []string{"127.0.0.1"}, Deny: []string{"127.0.0.1"}, - Snippets: []string{"# server snippet"}, + LimitReqs: []LimitReq{ + { + ZoneName: "pol_rl_test_test_test", + Delay: 10, + Burst: 5, + }, + }, + LimitReqOptions: LimitReqOptions{ + LogLevel: "error", + RejectCode: 503, + }, + Snippets: []string{"# server snippet"}, InternalRedirectLocations: []InternalRedirectLocation{ { Path: "/split", @@ -133,10 +149,15 @@ var virtualServerCfg = VirtualServerConfig{ }, Locations: []Location{ { - Path: "/", - Snippets: []string{"# location snippet"}, - Allow: []string{"127.0.0.1"}, - Deny: []string{"127.0.0.1"}, + Path: "/", + Snippets: []string{"# location snippet"}, + Allow: []string{"127.0.0.1"}, + Deny: []string{"127.0.0.1"}, + LimitReqs: []LimitReq{ + { + ZoneName: "loc_pol_rl_test_test_test", + }, + }, ProxyConnectTimeout: "30s", ProxyReadTimeout: "31s", ProxySendTimeout: "32s", diff --git a/internal/configs/virtualserver.go b/internal/configs/virtualserver.go index 7e2fd09311..82e3a73e38 100644 --- a/internal/configs/virtualserver.go +++ b/internal/configs/virtualserver.go @@ -191,30 +191,33 @@ func (vsc *virtualServerConfigurator) generateEndpointsForUpstream(owner runtime } // GenerateVirtualServerConfig generates a full configuration for a VirtualServer -func (vsc *virtualServerConfigurator) GenerateVirtualServerConfig(virtualServerEx *VirtualServerEx, tlsPemFileName string) (version2.VirtualServerConfig, Warnings) { +func (vsc *virtualServerConfigurator) GenerateVirtualServerConfig(vsEx *VirtualServerEx, tlsPemFileName string) (version2.VirtualServerConfig, Warnings) { vsc.clearWarnings() - policiesCfg := vsc.generatePolicies(virtualServerEx.VirtualServer, virtualServerEx.VirtualServer.Namespace, - virtualServerEx.VirtualServer.Spec.Policies, virtualServerEx.Policies) + policiesCfg := vsc.generatePolicies(vsEx.VirtualServer, vsEx.VirtualServer.Namespace, vsEx.VirtualServer.Namespace, + vsEx.VirtualServer.Name, vsEx.VirtualServer.Spec.Policies, vsEx.Policies) // crUpstreams maps an UpstreamName to its conf_v1.Upstream as they are generated // necessary for generateLocation to know what Upstream each Location references crUpstreams := make(map[string]conf_v1.Upstream) - virtualServerUpstreamNamer := newUpstreamNamerForVirtualServer(virtualServerEx.VirtualServer) + virtualServerUpstreamNamer := newUpstreamNamerForVirtualServer(vsEx.VirtualServer) var upstreams []version2.Upstream var statusMatches []version2.StatusMatch var healthChecks []version2.HealthCheck + var limitReqZones []version2.LimitReqZone + + limitReqZones = append(limitReqZones, policiesCfg.LimitReqZones...) // generate upstreams for VirtualServer - for _, u := range virtualServerEx.VirtualServer.Spec.Upstreams { + for _, u := range vsEx.VirtualServer.Spec.Upstreams { upstreamName := virtualServerUpstreamNamer.GetNameForUpstream(u.Name) - upstreamNamespace := virtualServerEx.VirtualServer.Namespace - endpoints := vsc.generateEndpointsForUpstream(virtualServerEx.VirtualServer, upstreamNamespace, u, virtualServerEx) + upstreamNamespace := vsEx.VirtualServer.Namespace + endpoints := vsc.generateEndpointsForUpstream(vsEx.VirtualServer, upstreamNamespace, u, vsEx) // isExternalNameSvc is always false for OSS - _, isExternalNameSvc := virtualServerEx.ExternalNameSvcs[GenerateExternalNameSvcKey(upstreamNamespace, u.Service)] - ups := vsc.generateUpstream(virtualServerEx.VirtualServer, upstreamName, u, isExternalNameSvc, endpoints) + _, isExternalNameSvc := vsEx.ExternalNameSvcs[GenerateExternalNameSvcKey(upstreamNamespace, u.Service)] + ups := vsc.generateUpstream(vsEx.VirtualServer, upstreamName, u, isExternalNameSvc, endpoints) upstreams = append(upstreams, ups) u.TLS.Enable = isTLSEnabled(u, vsc.spiffeCerts) @@ -228,15 +231,15 @@ func (vsc *virtualServerConfigurator) GenerateVirtualServerConfig(virtualServerE } } // generate upstreams for each VirtualServerRoute - for _, vsr := range virtualServerEx.VirtualServerRoutes { - upstreamNamer := newUpstreamNamerForVirtualServerRoute(virtualServerEx.VirtualServer, vsr) + for _, vsr := range vsEx.VirtualServerRoutes { + upstreamNamer := newUpstreamNamerForVirtualServerRoute(vsEx.VirtualServer, vsr) for _, u := range vsr.Spec.Upstreams { upstreamName := upstreamNamer.GetNameForUpstream(u.Name) upstreamNamespace := vsr.Namespace - endpoints := vsc.generateEndpointsForUpstream(vsr, upstreamNamespace, u, virtualServerEx) + endpoints := vsc.generateEndpointsForUpstream(vsr, upstreamNamespace, u, vsEx) // isExternalNameSvc is always false for OSS - _, isExternalNameSvc := virtualServerEx.ExternalNameSvcs[GenerateExternalNameSvcKey(upstreamNamespace, u.Service)] + _, isExternalNameSvc := vsEx.ExternalNameSvcs[GenerateExternalNameSvcKey(upstreamNamespace, u.Service)] ups := vsc.generateUpstream(vsr, upstreamName, u, isExternalNameSvc, endpoints) upstreams = append(upstreams, ups) u.TLS.Enable = isTLSEnabled(u, vsc.spiffeCerts) @@ -263,10 +266,10 @@ func (vsc *virtualServerConfigurator) GenerateVirtualServerConfig(virtualServerE var vsrPoliciesFromVs = make(map[string][]conf_v1.PolicyReference) matchesRoutes := 0 - variableNamer := newVariableNamer(virtualServerEx.VirtualServer) + variableNamer := newVariableNamer(vsEx.VirtualServer) // generates config for VirtualServer routes - for _, r := range virtualServerEx.VirtualServer.Spec.Routes { + for _, r := range vsEx.VirtualServer.Spec.Routes { errorPageIndex := len(errorPageLocations) errorPageLocations = append(errorPageLocations, generateErrorPageLocations(errorPageIndex, r.ErrorPages)...) @@ -274,7 +277,7 @@ func (vsc *virtualServerConfigurator) GenerateVirtualServerConfig(virtualServerE if r.Route != "" { name := r.Route if !strings.Contains(name, "/") { - name = fmt.Sprintf("%v/%v", virtualServerEx.VirtualServer.Namespace, r.Route) + name = fmt.Sprintf("%v/%v", vsEx.VirtualServer.Namespace, r.Route) } // store route location snippet for the referenced VirtualServerRoute in case they don't define their own @@ -297,8 +300,9 @@ func (vsc *virtualServerConfigurator) GenerateVirtualServerConfig(virtualServerE } vsLocSnippets := r.LocationSnippets - routePoliciesCfg := vsc.generatePolicies(virtualServerEx.VirtualServer, virtualServerEx.VirtualServer.Namespace, - r.Policies, virtualServerEx.Policies) + routePoliciesCfg := vsc.generatePolicies(vsEx.VirtualServer, vsEx.VirtualServer.Namespace, vsEx.VirtualServer.Namespace, vsEx.VirtualServer.Name, + r.Policies, vsEx.Policies) + limitReqZones = append(limitReqZones, routePoliciesCfg.LimitReqZones...) if len(r.Matches) > 0 { cfg := generateMatchesConfig(r, virtualServerUpstreamNamer, crUpstreams, variableNamer, matchesRoutes, len(splitClients), @@ -323,7 +327,7 @@ func (vsc *virtualServerConfigurator) GenerateVirtualServerConfig(virtualServerE } else { upstreamName := virtualServerUpstreamNamer.GetNameForUpstreamFromAction(r.Action) upstream := crUpstreams[upstreamName] - proxySSLName := generateProxySSLName(upstream.Service, virtualServerEx.VirtualServer.Namespace) + proxySSLName := generateProxySSLName(upstream.Service, vsEx.VirtualServer.Namespace) loc, returnLoc := generateLocation(r.Path, upstreamName, upstream, r.Action, vsc.cfgParams, r.ErrorPages, false, errorPageIndex, proxySSLName, r.Path, vsLocSnippets, vsc.enableSnippets, len(returnLocations)) @@ -337,8 +341,8 @@ func (vsc *virtualServerConfigurator) GenerateVirtualServerConfig(virtualServerE } // generate config for subroutes of each VirtualServerRoute - for _, vsr := range virtualServerEx.VirtualServerRoutes { - upstreamNamer := newUpstreamNamerForVirtualServerRoute(virtualServerEx.VirtualServer, vsr) + for _, vsr := range vsEx.VirtualServerRoutes { + upstreamNamer := newUpstreamNamerForVirtualServerRoute(vsEx.VirtualServer, vsr) for _, r := range vsr.Spec.Subroutes { errorPageIndex := len(errorPageLocations) errorPageLocations = append(errorPageLocations, generateErrorPageLocations(errorPageIndex, r.ErrorPages)...) @@ -358,12 +362,14 @@ func (vsc *virtualServerConfigurator) GenerateVirtualServerConfig(virtualServerE locSnippets = vsrLocationSnippetsFromVs[vsrNamespaceName] } - routePoliciesCfg := vsc.generatePolicies(vsr, vsr.Namespace, r.Policies, virtualServerEx.Policies) + routePoliciesCfg := vsc.generatePolicies(vsr, vsr.Namespace, vsEx.VirtualServer.Namespace, vsEx.VirtualServer.Name, + r.Policies, vsEx.Policies) // use the VirtualServer route policies if the route does not define any if len(r.Policies) == 0 { - routePoliciesCfg = vsc.generatePolicies(virtualServerEx.VirtualServer, virtualServerEx.VirtualServer.Namespace, - vsrPoliciesFromVs[vsrNamespaceName], virtualServerEx.Policies) + routePoliciesCfg = vsc.generatePolicies(vsEx.VirtualServer, vsEx.VirtualServer.Namespace, vsEx.VirtualServer.Namespace, + vsEx.VirtualServer.Name, vsrPoliciesFromVs[vsrNamespaceName], vsEx.Policies) } + limitReqZones = append(limitReqZones, routePoliciesCfg.LimitReqZones...) if len(r.Matches) > 0 { cfg := generateMatchesConfig(r, upstreamNamer, crUpstreams, variableNamer, matchesRoutes, len(splitClients), @@ -402,20 +408,21 @@ func (vsc *virtualServerConfigurator) GenerateVirtualServerConfig(virtualServerE } } - ssl := generateSSLConfig(virtualServerEx.VirtualServer.Spec.TLS, tlsPemFileName, vsc.cfgParams) - tlsRedirectConfig := generateTLSRedirectConfig(virtualServerEx.VirtualServer.Spec.TLS) - httpSnippets := generateSnippets(vsc.enableSnippets, virtualServerEx.VirtualServer.Spec.HTTPSnippets, []string{""}) - serverSnippets := generateSnippets(vsc.enableSnippets, virtualServerEx.VirtualServer.Spec.ServerSnippets, vsc.cfgParams.ServerSnippets) + ssl := generateSSLConfig(vsEx.VirtualServer.Spec.TLS, tlsPemFileName, vsc.cfgParams) + tlsRedirectConfig := generateTLSRedirectConfig(vsEx.VirtualServer.Spec.TLS) + httpSnippets := generateSnippets(vsc.enableSnippets, vsEx.VirtualServer.Spec.HTTPSnippets, []string{""}) + serverSnippets := generateSnippets(vsc.enableSnippets, vsEx.VirtualServer.Spec.ServerSnippets, vsc.cfgParams.ServerSnippets) vsCfg := version2.VirtualServerConfig{ Upstreams: upstreams, SplitClients: splitClients, Maps: maps, StatusMatches: statusMatches, + LimitReqZones: removeDuplicateLimitReqZones(limitReqZones), HTTPSnippets: httpSnippets, Server: version2.Server{ - ServerName: virtualServerEx.VirtualServer.Spec.Host, - StatusZone: virtualServerEx.VirtualServer.Spec.Host, + ServerName: vsEx.VirtualServer.Spec.Host, + StatusZone: vsEx.VirtualServer.Spec.Host, ProxyProtocol: vsc.cfgParams.ProxyProtocol, SSL: ssl, ServerTokens: vsc.cfgParams.ServerTokens, @@ -432,6 +439,8 @@ func (vsc *virtualServerConfigurator) GenerateVirtualServerConfig(virtualServerE TLSPassthrough: vsc.isTLSPassthrough, Allow: policiesCfg.Allow, Deny: policiesCfg.Deny, + LimitReqOptions: policiesCfg.LimitReqOptions, + LimitReqs: policiesCfg.LimitReqs, PoliciesErrorReturn: policiesCfg.ErrorReturn, }, SpiffeCerts: vsc.spiffeCerts, @@ -441,15 +450,21 @@ func (vsc *virtualServerConfigurator) GenerateVirtualServerConfig(virtualServerE } type policiesCfg struct { - Allow []string - Deny []string - ErrorReturn *version2.Return + Allow []string + Deny []string + LimitReqOptions version2.LimitReqOptions + LimitReqZones []version2.LimitReqZone + LimitReqs []version2.LimitReq + ErrorReturn *version2.Return } -func (vsc *virtualServerConfigurator) generatePolicies(owner runtime.Object, ownerNamespace string, policyRefs []conf_v1.PolicyReference, - policies map[string]*conf_v1alpha1.Policy) policiesCfg { +func (vsc *virtualServerConfigurator) generatePolicies(owner runtime.Object, ownerNamespace string, vsNamespace string, + vsName string, policyRefs []conf_v1.PolicyReference, policies map[string]*conf_v1alpha1.Policy) policiesCfg { var policyErrorReturn *version2.Return var allow, deny []string + var limitReqOptions version2.LimitReqOptions + var limitReqZones []version2.LimitReqZone + var limitReqs []version2.LimitReq var policyError bool for _, p := range policyRefs { @@ -465,6 +480,28 @@ func (vsc *virtualServerConfigurator) generatePolicies(owner runtime.Object, own allow = append(allow, pol.Spec.AccessControl.Allow...) deny = append(deny, pol.Spec.AccessControl.Deny...) } + if pol.Spec.RateLimit != nil { + rlZoneName := fmt.Sprintf("pol_rl_%v_%v_%v_%v", polNamespace, p.Name, vsNamespace, vsName) + limitReqs = append(limitReqs, generateLimitReq(rlZoneName, pol.Spec.RateLimit)) + limitReqZones = append(limitReqZones, generateLimitReqZone(rlZoneName, pol.Spec.RateLimit)) + if len(limitReqs) == 1 { + limitReqOptions = generateLimitReqOptions(pol.Spec.RateLimit) + } else { + curOptions := generateLimitReqOptions(pol.Spec.RateLimit) + if curOptions.DryRun != limitReqOptions.DryRun { + vsc.addWarningf(owner, "RateLimit policy %v with limit request option dryRun=%v is overridden to dryRun=%v by the first policy reference in this context", + key, curOptions.DryRun, limitReqOptions.DryRun) + } + if curOptions.LogLevel != limitReqOptions.LogLevel { + vsc.addWarningf(owner, "RateLimit policy %v with limit request option logLevel=%v is overridden to logLevel=%v by the first policy reference in this context", + key, curOptions.LogLevel, limitReqOptions.LogLevel) + } + if curOptions.RejectCode != limitReqOptions.RejectCode { + vsc.addWarningf(owner, "RateLimit policy %v with limit request option rejectCode=%v is overridden to rejectCode=%v by the first policy reference in this context", + key, curOptions.RejectCode, limitReqOptions.RejectCode) + } + } + } } else { vsc.addWarningf(owner, "Policy %s is missing or invalid", key) policyError = true @@ -472,23 +509,79 @@ func (vsc *virtualServerConfigurator) generatePolicies(owner runtime.Object, own } } if policyError { - allow = []string{} - deny = []string{} - policyErrorReturn = &version2.Return{Code: 500} + return policiesCfg{ + ErrorReturn: &version2.Return{Code: 500}, + } } else if len(allow) > 0 && len(deny) > 0 { vsc.addWarningf(owner, "AccessControl policy (or policies) with deny rules is overridden by policy (or policies) with allow rules") } return policiesCfg{ - Allow: allow, - Deny: deny, - ErrorReturn: policyErrorReturn, + Allow: allow, + Deny: deny, + LimitReqOptions: limitReqOptions, + LimitReqZones: limitReqZones, + LimitReqs: limitReqs, + ErrorReturn: policyErrorReturn, } } +func generateLimitReq(zoneName string, rateLimitPol *conf_v1alpha1.RateLimit) version2.LimitReq { + var limitReq version2.LimitReq + + limitReq.ZoneName = zoneName + + if rateLimitPol.Burst != nil { + limitReq.Burst = *rateLimitPol.Burst + } + if rateLimitPol.Delay != nil { + limitReq.Delay = *rateLimitPol.Delay + } + + limitReq.NoDelay = generateBool(rateLimitPol.NoDelay, false) + if limitReq.NoDelay { + limitReq.Delay = 0 + } + + return limitReq +} + +func generateLimitReqZone(zoneName string, rateLimitPol *conf_v1alpha1.RateLimit) version2.LimitReqZone { + return version2.LimitReqZone{ + ZoneName: zoneName, + Key: rateLimitPol.Key, + ZoneSize: rateLimitPol.ZoneSize, + Rate: rateLimitPol.Rate, + } +} + +func generateLimitReqOptions(rateLimitPol *conf_v1alpha1.RateLimit) version2.LimitReqOptions { + return version2.LimitReqOptions{ + DryRun: generateBool(rateLimitPol.DryRun, false), + LogLevel: generateString(rateLimitPol.LogLevel, "error"), + RejectCode: generateIntFromPointer(rateLimitPol.RejectCode, 503), + } +} + +func removeDuplicateLimitReqZones(rlz []version2.LimitReqZone) []version2.LimitReqZone { + encountered := make(map[string]bool) + result := []version2.LimitReqZone{} + + for _, v := range rlz { + if !encountered[v.ZoneName] { + encountered[v.ZoneName] = true + result = append(result, v) + } + } + + return result +} + func addPoliciesCfgToLocation(cfg policiesCfg, location *version2.Location) { location.Allow = cfg.Allow location.Deny = cfg.Deny + location.LimitReqOptions = cfg.LimitReqOptions + location.LimitReqs = cfg.LimitReqs location.PoliciesErrorReturn = cfg.ErrorReturn } @@ -527,7 +620,6 @@ func (vsc *virtualServerConfigurator) generateUpstream(owner runtime.Object, ups s := version2.UpstreamServer{ Address: e, } - upsServers = append(upsServers, s) } diff --git a/internal/configs/virtualserver_test.go b/internal/configs/virtualserver_test.go index 2d5929ae8c..00c23a6a72 100644 --- a/internal/configs/virtualserver_test.go +++ b/internal/configs/virtualserver_test.go @@ -471,7 +471,8 @@ func TestGenerateVirtualServerConfig(t *testing.T) { Keepalive: 16, }, }, - HTTPSnippets: []string{""}, + HTTPSnippets: []string{""}, + LimitReqZones: []version2.LimitReqZone{}, Server: version2.Server{ ServerName: "cafe.example.com", StatusZone: "cafe.example.com", @@ -667,7 +668,8 @@ func TestGenerateVirtualServerConfigWithSpiffeCerts(t *testing.T) { Keepalive: 16, }, }, - HTTPSnippets: []string{""}, + HTTPSnippets: []string{""}, + LimitReqZones: []version2.LimitReqZone{}, Server: version2.Server{ ServerName: "cafe.example.com", StatusZone: "cafe.example.com", @@ -903,7 +905,8 @@ func TestGenerateVirtualServerConfigForVirtualServerWithSplits(t *testing.T) { }, }, }, - HTTPSnippets: []string{""}, + HTTPSnippets: []string{""}, + LimitReqZones: []version2.LimitReqZone{}, Server: version2.Server{ ServerName: "cafe.example.com", StatusZone: "cafe.example.com", @@ -1203,7 +1206,8 @@ func TestGenerateVirtualServerConfigForVirtualServerWithMatches(t *testing.T) { }, }, }, - HTTPSnippets: []string{""}, + HTTPSnippets: []string{""}, + LimitReqZones: []version2.LimitReqZone{}, Server: version2.Server{ ServerName: "cafe.example.com", StatusZone: "cafe.example.com", @@ -1503,7 +1507,8 @@ func TestGenerateVirtualServerConfigForVirtualServerWithReturns(t *testing.T) { }, }, }, - HTTPSnippets: []string{""}, + HTTPSnippets: []string{""}, + LimitReqZones: []version2.LimitReqZone{}, Server: version2.Server{ ServerName: "example.com", StatusZone: "example.com", @@ -1749,6 +1754,8 @@ func TestGenerateVirtualServerConfigForVirtualServerWithReturns(t *testing.T) { func TestGeneratePolicies(t *testing.T) { var owner runtime.Object // nil is OK for the unit test ownerNamespace := "default" + vsNamespace := "default" + vsName := "test" tests := []struct { policyRefs []conf_v1.PolicyReference @@ -1827,12 +1834,113 @@ func TestGeneratePolicies(t *testing.T) { }, msg: "merging", }, + { + policyRefs: []conf_v1.PolicyReference{ + { + Name: "rateLimit-policy", + Namespace: "default", + }, + }, + policies: map[string]*conf_v1alpha1.Policy{ + "default/rateLimit-policy": { + Spec: conf_v1alpha1.PolicySpec{ + RateLimit: &conf_v1alpha1.RateLimit{ + Key: "test", + ZoneSize: "10M", + Rate: "10r/s", + LogLevel: "notice", + }, + }, + }, + }, + expected: policiesCfg{ + LimitReqZones: []version2.LimitReqZone{ + { + Key: "test", + ZoneSize: "10M", + Rate: "10r/s", + ZoneName: "pol_rl_default_rateLimit-policy_default_test", + }, + }, + LimitReqOptions: version2.LimitReqOptions{ + LogLevel: "notice", + RejectCode: 503, + }, + LimitReqs: []version2.LimitReq{ + { + ZoneName: "pol_rl_default_rateLimit-policy_default_test", + }, + }, + }, + msg: "rate limit reference", + }, + { + policyRefs: []conf_v1.PolicyReference{ + { + Name: "rateLimit-policy", + Namespace: "default", + }, + { + Name: "rateLimit-policy2", + Namespace: "default", + }, + }, + policies: map[string]*conf_v1alpha1.Policy{ + "default/rateLimit-policy": { + Spec: conf_v1alpha1.PolicySpec{ + RateLimit: &conf_v1alpha1.RateLimit{ + Key: "test", + ZoneSize: "10M", + Rate: "10r/s", + }, + }, + }, + "default/rateLimit-policy2": { + Spec: conf_v1alpha1.PolicySpec{ + RateLimit: &conf_v1alpha1.RateLimit{ + Key: "test2", + ZoneSize: "20M", + Rate: "20r/s", + }, + }, + }, + }, + expected: policiesCfg{ + LimitReqZones: []version2.LimitReqZone{ + { + Key: "test", + ZoneSize: "10M", + Rate: "10r/s", + ZoneName: "pol_rl_default_rateLimit-policy_default_test", + }, + { + Key: "test2", + ZoneSize: "20M", + Rate: "20r/s", + ZoneName: "pol_rl_default_rateLimit-policy2_default_test", + }, + }, + LimitReqOptions: version2.LimitReqOptions{ + LogLevel: "error", + RejectCode: 503, + }, + LimitReqs: []version2.LimitReq{ + { + ZoneName: "pol_rl_default_rateLimit-policy_default_test", + }, + { + ZoneName: "pol_rl_default_rateLimit-policy2_default_test", + }, + }, + }, + msg: "multi rate limit reference", + }, } vsc := newVirtualServerConfigurator(&ConfigParams{}, false, false, &StaticConfigParams{}) for _, test := range tests { - result := vsc.generatePolicies(owner, ownerNamespace, test.policyRefs, test.policies) + result := vsc.generatePolicies(owner, ownerNamespace, vsNamespace, vsName, test.policyRefs, test.policies) if !reflect.DeepEqual(result, test.expected) { t.Errorf("generatePolicies() returned \n%+v but expected \n%+v for the case of %s", result, test.expected, test.msg) @@ -1846,12 +1954,18 @@ func TestGeneratePolicies(t *testing.T) { func TestGeneratePoliciesFails(t *testing.T) { var owner runtime.Object // nil is OK for the unit test ownerNamespace := "default" + vsNamespace := "default" + vsName := "test" + + dryRunOverride := true + rejectCodeOverride := 505 tests := []struct { - policyRefs []conf_v1.PolicyReference - policies map[string]*conf_v1alpha1.Policy - expected policiesCfg - msg string + policyRefs []conf_v1.PolicyReference + policies map[string]*conf_v1alpha1.Policy + expected policiesCfg + expectedWarnings Warnings + msg string }{ { policyRefs: []conf_v1.PolicyReference{ @@ -1862,12 +1976,15 @@ func TestGeneratePoliciesFails(t *testing.T) { }, policies: map[string]*conf_v1alpha1.Policy{}, expected: policiesCfg{ - Allow: []string{}, - Deny: []string{}, ErrorReturn: &version2.Return{ Code: 500, }, }, + expectedWarnings: map[runtime.Object][]string{ + nil: { + "Policy default/allow-policy is missing or invalid", + }, + }, msg: "missing policy", }, { @@ -1899,20 +2016,137 @@ func TestGeneratePoliciesFails(t *testing.T) { Allow: []string{"127.0.0.1"}, Deny: []string{"127.0.0.2"}, }, + expectedWarnings: map[runtime.Object][]string{ + nil: { + "AccessControl policy (or policies) with deny rules is overridden by policy (or policies) with allow rules", + }, + }, msg: "conflicting policies", }, + { + policyRefs: []conf_v1.PolicyReference{ + { + Name: "rateLimit-policy", + Namespace: "default", + }, + { + Name: "rateLimit-policy2", + Namespace: "default", + }, + }, + policies: map[string]*conf_v1alpha1.Policy{ + "default/rateLimit-policy": { + Spec: conf_v1alpha1.PolicySpec{ + RateLimit: &conf_v1alpha1.RateLimit{ + Key: "test", + ZoneSize: "10M", + Rate: "10r/s", + }, + }, + }, + "default/rateLimit-policy2": { + Spec: conf_v1alpha1.PolicySpec{ + RateLimit: &conf_v1alpha1.RateLimit{ + Key: "test2", + ZoneSize: "20M", + Rate: "20r/s", + DryRun: &dryRunOverride, + LogLevel: "info", + RejectCode: &rejectCodeOverride, + }, + }, + }, + }, + expected: policiesCfg{ + LimitReqZones: []version2.LimitReqZone{ + { + Key: "test", + ZoneSize: "10M", + Rate: "10r/s", + ZoneName: "pol_rl_default_rateLimit-policy_default_test", + }, + { + Key: "test2", + ZoneSize: "20M", + Rate: "20r/s", + ZoneName: "pol_rl_default_rateLimit-policy2_default_test", + }, + }, + LimitReqOptions: version2.LimitReqOptions{ + LogLevel: "error", + RejectCode: 503, + }, + LimitReqs: []version2.LimitReq{ + { + ZoneName: "pol_rl_default_rateLimit-policy_default_test", + }, + { + ZoneName: "pol_rl_default_rateLimit-policy2_default_test", + }, + }, + }, + expectedWarnings: map[runtime.Object][]string{ + nil: { + "RateLimit policy default/rateLimit-policy2 with limit request option dryRun=true is overridden to dryRun=false by the first policy reference in this context", + "RateLimit policy default/rateLimit-policy2 with limit request option logLevel=info is overridden to logLevel=error by the first policy reference in this context", + "RateLimit policy default/rateLimit-policy2 with limit request option rejectCode=505 is overridden to rejectCode=503 by the first policy reference in this context", + }, + }, + msg: "rate limit policy limit request option override", + }, } - vsc := newVirtualServerConfigurator(&ConfigParams{}, false, false, &StaticConfigParams{}) - for _, test := range tests { - result := vsc.generatePolicies(owner, ownerNamespace, test.policyRefs, test.policies) + vsc := newVirtualServerConfigurator(&ConfigParams{}, false, false, &StaticConfigParams{}) + + result := vsc.generatePolicies(owner, ownerNamespace, vsNamespace, vsName, test.policyRefs, test.policies) if !reflect.DeepEqual(result, test.expected) { t.Errorf("generatePolicies() returned \n%+v but expected \n%+v for the case of %s", result, test.expected, test.msg) } - if len(vsc.warnings) == 0 { - t.Errorf("generatePolicies() returned no warnings for the case of %s", test.msg) + if !reflect.DeepEqual(vsc.warnings, test.expectedWarnings) { + t.Errorf("generatePolicies() returned warnings of \n%v but expected \n%v for the case of %s", vsc.warnings, test.expectedWarnings, test.msg) + } + } +} + +func TestRemoveDuplicates(t *testing.T) { + tests := []struct { + rlz []version2.LimitReqZone + expected []version2.LimitReqZone + }{ + { + rlz: []version2.LimitReqZone{ + {ZoneName: "test"}, + {ZoneName: "test"}, + {ZoneName: "test2"}, + {ZoneName: "test3"}, + }, + expected: []version2.LimitReqZone{ + {ZoneName: "test"}, + {ZoneName: "test2"}, + {ZoneName: "test3"}, + }, + }, + { + rlz: []version2.LimitReqZone{ + {ZoneName: "test"}, + {ZoneName: "test"}, + {ZoneName: "test2"}, + {ZoneName: "test3"}, + {ZoneName: "test3"}, + }, + expected: []version2.LimitReqZone{ + {ZoneName: "test"}, + {ZoneName: "test2"}, + {ZoneName: "test3"}, + }, + }, + } + for _, test := range tests { + result := removeDuplicateLimitReqZones(test.rlz) + if !reflect.DeepEqual(result, test.expected) { + t.Errorf("removeDuplicates() returned \n%v, but expected \n%v", result, test.expected) } } } @@ -4047,7 +4281,7 @@ func TestUpstreamHasKeepalive(t *testing.T) { conf_v1.Upstream{Keepalive: &noKeepalive}, &ConfigParams{Keepalive: keepalive}, false, - "upstream keepalive set to 0, configparam keepive set", + "upstream keepalive set to 0, configparam keepalive set", }, { conf_v1.Upstream{Keepalive: &keepalive}, diff --git a/internal/k8s/controller_test.go b/internal/k8s/controller_test.go index 07f9b0135f..7ebbb97255 100644 --- a/internal/k8s/controller_test.go +++ b/internal/k8s/controller_test.go @@ -2105,17 +2105,17 @@ func TestGetPolicies(t *testing.T) { expectedPolicies := []*conf_v1alpha1.Policy{validPolicy} expectedErrors := []error{ - errors.New("Policy default/invalid-policy is invalid: spec: Invalid value: \"\": must specify exactly one of: `accessControl`"), + errors.New("Policy default/invalid-policy is invalid: spec: Invalid value: \"\": must specify exactly one of: `accessControl`, `rateLimit`"), errors.New("Policy nginx-ingress/valid-policy doesn't exist"), errors.New("Failed to get policy nginx-ingress/some-policy: GetByKey error"), } result, errors := lbc.getPolicies(policyRefs, "default") if !reflect.DeepEqual(result, expectedPolicies) { - t.Errorf("lbc.getPolicies() returned %v but expected %v", result, expectedPolicies) + t.Errorf("lbc.getPolicies() returned \n%v but \nexpected %v", result, expectedPolicies) } if !reflect.DeepEqual(errors, expectedErrors) { - t.Errorf("lbc.getPolicies() returned %v but expected %v", errors, expectedErrors) + t.Errorf("lbc.getPolicies() returned \n%v but expected \n%v", errors, expectedErrors) } } diff --git a/pkg/apis/configuration/v1alpha1/types.go b/pkg/apis/configuration/v1alpha1/types.go index c79c276c5d..1f04c6d776 100644 --- a/pkg/apis/configuration/v1alpha1/types.go +++ b/pkg/apis/configuration/v1alpha1/types.go @@ -114,10 +114,10 @@ type Policy struct { // PolicySpec is the spec of the Policy resource. // The spec includes multiple fields, where each field represents a different policy. -// Note: currently we have only one policy -- AccessControl, but we will support more in the future. // Only one policy (field) is allowed. type PolicySpec struct { AccessControl *AccessControl `json:"accessControl"` + RateLimit *RateLimit `json:"rateLimit"` } // AccessControl defines an access policy based on the source IP of a request. @@ -126,6 +126,19 @@ type AccessControl struct { Deny []string `json:"deny"` } +// RateLimit defines a rate limit policy. +type RateLimit struct { + Rate string `json:"rate"` + Key string `json:"key"` + Delay *int `json:"delay"` + NoDelay *bool `json:"noDelay"` + Burst *int `json:"burst"` + ZoneSize string `json:"zoneSize"` + DryRun *bool `json:"dryRun"` + LogLevel string `json:"logLevel"` + RejectCode *int `json:"rejectCode"` +} + // +k8s:deepcopy-gen:interfaces=k8s.io/apimachinery/pkg/runtime.Object // PolicyList is a list of the Policy resources. diff --git a/pkg/apis/configuration/v1alpha1/zz_generated.deepcopy.go b/pkg/apis/configuration/v1alpha1/zz_generated.deepcopy.go index 3263562e49..b6a56ea675 100644 --- a/pkg/apis/configuration/v1alpha1/zz_generated.deepcopy.go +++ b/pkg/apis/configuration/v1alpha1/zz_generated.deepcopy.go @@ -215,6 +215,11 @@ func (in *PolicySpec) DeepCopyInto(out *PolicySpec) { *out = new(AccessControl) (*in).DeepCopyInto(*out) } + if in.RateLimit != nil { + in, out := &in.RateLimit, &out.RateLimit + *out = new(RateLimit) + (*in).DeepCopyInto(*out) + } return } @@ -228,6 +233,47 @@ func (in *PolicySpec) DeepCopy() *PolicySpec { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *RateLimit) DeepCopyInto(out *RateLimit) { + *out = *in + if in.Delay != nil { + in, out := &in.Delay, &out.Delay + *out = new(int) + **out = **in + } + if in.NoDelay != nil { + in, out := &in.NoDelay, &out.NoDelay + *out = new(bool) + **out = **in + } + if in.Burst != nil { + in, out := &in.Burst, &out.Burst + *out = new(int) + **out = **in + } + if in.DryRun != nil { + in, out := &in.DryRun, &out.DryRun + *out = new(bool) + **out = **in + } + if in.RejectCode != nil { + in, out := &in.RejectCode, &out.RejectCode + *out = new(int) + **out = **in + } + return +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new RateLimit. +func (in *RateLimit) DeepCopy() *RateLimit { + if in == nil { + return nil + } + out := new(RateLimit) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *TransportServer) DeepCopyInto(out *TransportServer) { *out = *in diff --git a/pkg/apis/configuration/validation/common.go b/pkg/apis/configuration/validation/common.go new file mode 100644 index 0000000000..e38a54a024 --- /dev/null +++ b/pkg/apis/configuration/validation/common.go @@ -0,0 +1,128 @@ +package validation + +import ( + "fmt" + "regexp" + "strings" + + "k8s.io/apimachinery/pkg/util/validation" + "k8s.io/apimachinery/pkg/util/validation/field" +) + +const ( + escapedStringsFmt = `([^"\\]|\\.)*` + escapedStringsErrMsg = `must have all '"' (double quotes) escaped and must not end with an unescaped '\' (backslash)` +) + +var escapedStringsFmtRegexp = regexp.MustCompile("^" + escapedStringsFmt + "$") + +func validateVariable(nVar string, validVars map[string]bool, fieldPath *field.Path) field.ErrorList { + allErrs := field.ErrorList{} + + if !validVars[nVar] { + msg := fmt.Sprintf("'%v' contains an invalid NGINX variable. Accepted variables are: %v", nVar, mapToPrettyString(validVars)) + allErrs = append(allErrs, field.Invalid(fieldPath, nVar, msg)) + } + return allErrs +} + +func isValidSpecialVariableHeader(header string) []string { + // underscores in $http_ variable represent '-'. + errMsgs := validation.IsHTTPHeaderName(strings.Replace(header, "_", "-", -1)) + if len(errMsgs) >= 1 || strings.Contains(header, "-") { + return []string{"a valid HTTP header must consist of alphanumeric characters or '_'"} + } + return nil +} + +func validateSpecialVariable(nVar string, fieldPath *field.Path) field.ErrorList { + allErrs := field.ErrorList{} + value := strings.SplitN(nVar, "_", 2) + + switch value[0] { + case "arg": + for _, msg := range isArgumentName(value[1]) { + allErrs = append(allErrs, field.Invalid(fieldPath, nVar, msg)) + } + case "http": + for _, msg := range isValidSpecialVariableHeader(value[1]) { + allErrs = append(allErrs, field.Invalid(fieldPath, nVar, msg)) + } + case "cookie": + for _, msg := range isCookieName(value[1]) { + allErrs = append(allErrs, field.Invalid(fieldPath, nVar, msg)) + } + } + + return allErrs +} + +func validateStringWithVariables(str string, fieldPath *field.Path, specialVars []string, validVars map[string]bool) field.ErrorList { + allErrs := field.ErrorList{} + + if strings.HasSuffix(str, "$") { + return append(allErrs, field.Invalid(fieldPath, str, "must not end with $")) + } + + for i, c := range str { + if c == '$' { + msg := "variables must be enclosed in curly braces, for example ${host}" + + if str[i+1] != '{' { + return append(allErrs, field.Invalid(fieldPath, str, msg)) + } + + if !strings.Contains(str[i+1:], "}") { + return append(allErrs, field.Invalid(fieldPath, str, msg)) + } + } + } + + nginxVars := captureVariables(str) + for _, nVar := range nginxVars { + special := false + for _, specialVar := range specialVars { + if strings.HasPrefix(nVar, specialVar) { + special = true + break + } + } + + if special { + allErrs = append(allErrs, validateSpecialVariable(nVar, fieldPath)...) + } else { + allErrs = append(allErrs, validateVariable(nVar, validVars, fieldPath)...) + } + } + + return allErrs +} + +const sizeFmt = `\d+[kKmM]?` +const sizeErrMsg = "must consist of numeric characters followed by a valid size suffix. 'k|K|m|M" + +var sizeRegexp = regexp.MustCompile("^" + sizeFmt + "$") + +func validateSize(size string, fieldPath *field.Path) field.ErrorList { + allErrs := field.ErrorList{} + + if size == "" { + return allErrs + } + + if !sizeRegexp.MatchString(size) { + msg := validation.RegexError(sizeErrMsg, sizeFmt, "16", "32k", "64M") + return append(allErrs, field.Invalid(fieldPath, size, msg)) + } + return allErrs +} + +func mapToPrettyString(m map[string]bool) string { + var out []string + + for k := range m { + out = append(out, k) + } + + return strings.Join(out, ", ") +} diff --git a/pkg/apis/configuration/validation/common_test.go b/pkg/apis/configuration/validation/common_test.go new file mode 100644 index 0000000000..fcbc68ecbd --- /dev/null +++ b/pkg/apis/configuration/validation/common_test.go @@ -0,0 +1,154 @@ +package validation + +import ( + "testing" + + "k8s.io/apimachinery/pkg/util/validation/field" +) + +func createPointerFromInt(n int) *int { + return &n +} + +func TestValidateVariable(t *testing.T) { + var validVars = map[string]bool{ + "scheme": true, + "http_x_forwarded_proto": true, + "request_uri": true, + "host": true, + } + + validTests := []string{ + "scheme", + "http_x_forwarded_proto", + "request_uri", + "host", + } + for _, nVar := range validTests { + allErrs := validateVariable(nVar, validVars, field.NewPath("url")) + if len(allErrs) != 0 { + t.Errorf("validateVariable(%v) returned errors %v for valid input", nVar, allErrs) + } + } +} + +func TestValidateVariableFails(t *testing.T) { + var validVars = map[string]bool{ + "host": true, + } + invalidVars := []string{ + "", + "hostinvalid.com", + "$a", + "host${host}", + "host${host}}", + "host$${host}", + } + for _, nVar := range invalidVars { + allErrs := validateVariable(nVar, validVars, field.NewPath("url")) + if len(allErrs) == 0 { + t.Errorf("validateVariable(%v) returned no errors for invalid input", nVar) + } + } +} + +func TestValidateSpecialVariable(t *testing.T) { + specialVars := []string{"arg_username", "arg_user_name", "http_header_name", "cookie_cookie_name"} + for _, v := range specialVars { + allErrs := validateSpecialVariable(v, field.NewPath("variable")) + if len(allErrs) != 0 { + t.Errorf("validateSpecialVariable(%v) returned errors for valid case: %v", v, allErrs) + } + } +} + +func TestValidateSpecialVariableFails(t *testing.T) { + specialVars := []string{"arg_invalid%", "http_header+invalid", "cookie_cookie_name?invalid"} + for _, v := range specialVars { + allErrs := validateSpecialVariable(v, field.NewPath("variable")) + if len(allErrs) == 0 { + t.Errorf("validateSpecialVariable(%v) returned no errors for invalid case", v) + } + } +} + +func TestValidateStringWithVariables(t *testing.T) { + testStrings := []string{ + "", + "${scheme}", + "${scheme}${host}", + "foo.bar", + } + validVars := map[string]bool{"scheme": true, "host": true} + + for _, test := range testStrings { + allErrs := validateStringWithVariables(test, field.NewPath("string"), nil, validVars) + if len(allErrs) != 0 { + t.Errorf("validateStringWithVariables(%v) returned errors for valid input: %v", test, allErrs) + } + } + + specialVars := []string{"arg", "http", "cookie"} + testStringsSpecial := []string{ + "${arg_username}", + "${http_header_name}", + "${cookie_cookie_name}", + } + + for _, test := range testStringsSpecial { + allErrs := validateStringWithVariables(test, field.NewPath("string"), specialVars, validVars) + if len(allErrs) != 0 { + t.Errorf("validateStringWithVariables(%v) returned errors for valid input: %v", test, allErrs) + } + } +} + +func TestValidateStringWithVariablesFail(t *testing.T) { + testStrings := []string{ + "$scheme}", + "${sch${eme}${host}", + "host$", + "${host", + "${invalid}", + } + validVars := map[string]bool{"scheme": true, "host": true} + + for _, test := range testStrings { + allErrs := validateStringWithVariables(test, field.NewPath("string"), nil, validVars) + if len(allErrs) == 0 { + t.Errorf("validateStringWithVariables(%v) returned no errors for invalid input", test) + } + } + + specialVars := []string{"arg", "http", "cookie"} + testStringsSpecial := []string{ + "${arg_username%}", + "${http_header-name}", + "${cookie_cookie?name}", + } + + for _, test := range testStringsSpecial { + allErrs := validateStringWithVariables(test, field.NewPath("string"), specialVars, validVars) + if len(allErrs) == 0 { + t.Errorf("validateStringWithVariables(%v) returned no errors for invalid input", test) + } + } +} + +func TestValidateSize(t *testing.T) { + var validInput = []string{"", "4k", "8K", "16m", "32M"} + for _, test := range validInput { + allErrs := validateSize(test, field.NewPath("size-field")) + if len(allErrs) != 0 { + t.Errorf("validateSize(%q) returned an error for valid input", test) + } + } + + var invalidInput = []string{"55mm", "2mG", "6kb", "-5k", "1L", "5G"} + for _, test := range invalidInput { + allErrs := validateSize(test, field.NewPath("size-field")) + if len(allErrs) == 0 { + t.Errorf("validateSize(%q) didn't return error for invalid input.", test) + } + } +} diff --git a/pkg/apis/configuration/validation/policy.go b/pkg/apis/configuration/validation/policy.go index 34b4dfcd0d..af18ac3185 100644 --- a/pkg/apis/configuration/validation/policy.go +++ b/pkg/apis/configuration/validation/policy.go @@ -1,9 +1,14 @@ package validation import ( + "fmt" "net" + "regexp" + "strconv" + "strings" "github.com/nginxinc/kubernetes-ingress/pkg/apis/configuration/v1alpha1" + "k8s.io/apimachinery/pkg/util/validation" "k8s.io/apimachinery/pkg/util/validation/field" ) @@ -23,8 +28,13 @@ func validatePolicySpec(spec *v1alpha1.PolicySpec, fieldPath *field.Path) field. fieldCount++ } + if spec.RateLimit != nil { + allErrs = append(allErrs, validateRateLimit(spec.RateLimit, fieldPath.Child("rateLimit"))...) + fieldCount++ + } + if fieldCount != 1 { - allErrs = append(allErrs, field.Invalid(fieldPath, "", "must specify exactly one of: `accessControl`")) + allErrs = append(allErrs, field.Invalid(fieldPath, "", "must specify exactly one of: `accessControl`, `rateLimit`")) } return allErrs @@ -56,6 +66,121 @@ func validateAccessControl(accessControl *v1alpha1.AccessControl, fieldPath *fie return allErrs } +func validateRateLimit(rateLimit *v1alpha1.RateLimit, fieldPath *field.Path) field.ErrorList { + allErrs := field.ErrorList{} + + allErrs = append(allErrs, validateRateLimitZoneSize(rateLimit.ZoneSize, fieldPath.Child("zoneSize"))...) + allErrs = append(allErrs, validateRate(rateLimit.Rate, fieldPath.Child("rate"))...) + allErrs = append(allErrs, validateRateLimitKey(rateLimit.Key, fieldPath.Child("key"))...) + + if rateLimit.Delay != nil { + allErrs = append(allErrs, validatePositiveInt(*rateLimit.Delay, fieldPath.Child("delay"))...) + } + + if rateLimit.Burst != nil { + allErrs = append(allErrs, validatePositiveInt(*rateLimit.Burst, fieldPath.Child("burst"))...) + } + + if rateLimit.LogLevel != "" { + allErrs = append(allErrs, validateRateLimitLogLevel(rateLimit.LogLevel, fieldPath.Child("logLevel"))...) + } + + if rateLimit.RejectCode != nil { + if *rateLimit.RejectCode < 400 || *rateLimit.RejectCode > 599 { + allErrs = append(allErrs, field.Invalid(fieldPath.Child("rejectCode"), rateLimit.RejectCode, + "must be within the range [400-599]")) + } + } + + return allErrs +} + +const rateFmt = `[1-9]\d*r/[sSmM]` +const rateErrMsg = "must consist of numeric characters followed by a valid rate suffix. 'r/s|r/m" + +var rateRegexp = regexp.MustCompile("^" + rateFmt + "$") + +func validateRate(rate string, fieldPath *field.Path) field.ErrorList { + allErrs := field.ErrorList{} + + if rate == "" { + return append(allErrs, field.Required(fieldPath, "")) + } + + if !rateRegexp.MatchString(rate) { + msg := validation.RegexError(rateErrMsg, rateFmt, "16r/s", "32r/m", "64r/s") + return append(allErrs, field.Invalid(fieldPath, rate, msg)) + } + return allErrs +} + +func validateRateLimitZoneSize(zoneSize string, fieldPath *field.Path) field.ErrorList { + allErrs := field.ErrorList{} + + if zoneSize == "" { + return append(allErrs, field.Required(fieldPath, "")) + } + + allErrs = append(allErrs, validateSize(zoneSize, fieldPath)...) + + kbZoneSize := strings.TrimSuffix(strings.ToLower(zoneSize), "k") + kbZoneSizeNum, err := strconv.Atoi(kbZoneSize) + + mbZoneSize := strings.TrimSuffix(strings.ToLower(zoneSize), "m") + mbZoneSizeNum, mbErr := strconv.Atoi(mbZoneSize) + + if err == nil && kbZoneSizeNum < 32 || mbErr == nil && mbZoneSizeNum == 0 { + allErrs = append(allErrs, field.Invalid(fieldPath, zoneSize, "must be greater than 31k")) + } + + return allErrs +} + +var rateLimitKeySpecialVariables = []string{"arg_", "http_", "cookie_"} + +// rateLimitVariables includes NGINX variables allowed to be used in a rateLimit policy key. +var rateLimitKeyVariables = map[string]bool{ + "binary_remote_addr": true, + "request_uri": true, + "uri": true, + "args": true, +} + +func validateRateLimitKey(key string, fieldPath *field.Path) field.ErrorList { + allErrs := field.ErrorList{} + + if key == "" { + return append(allErrs, field.Required(fieldPath, "")) + } + + if !escapedStringsFmtRegexp.MatchString(key) { + msg := validation.RegexError(escapedStringsErrMsg, escapedStringsFmt, `Hello World! \n`, `\"${request_uri}\" is unavailable. \n`) + allErrs = append(allErrs, field.Invalid(fieldPath, key, msg)) + } + + allErrs = append(allErrs, validateStringWithVariables(key, fieldPath, rateLimitKeySpecialVariables, rateLimitKeyVariables)...) + + return allErrs +} + +var validLogLevels = map[string]bool{ + "info": true, + "notice": true, + "warn": true, + "error": true, +} + +func validateRateLimitLogLevel(logLevel string, fieldPath *field.Path) field.ErrorList { + allErrs := field.ErrorList{} + + if !validLogLevels[logLevel] { + allErrs = append(allErrs, field.Invalid(fieldPath, logLevel, fmt.Sprintf("Accepted values: %s", + mapToPrettyString(validLogLevels)))) + } + + return allErrs +} + func validateIPorCIDR(ipOrCIDR string, fieldPath *field.Path) field.ErrorList { allErrs := field.ErrorList{} @@ -73,3 +198,13 @@ func validateIPorCIDR(ipOrCIDR string, fieldPath *field.Path) field.ErrorList { return append(allErrs, field.Invalid(fieldPath, ipOrCIDR, "must be a CIDR or IP")) } + +func validatePositiveInt(n int, fieldPath *field.Path) field.ErrorList { + allErrs := field.ErrorList{} + + if n <= 0 { + return append(allErrs, field.Invalid(fieldPath, n, "must be positive")) + } + + return allErrs +} diff --git a/pkg/apis/configuration/validation/policy_test.go b/pkg/apis/configuration/validation/policy_test.go index 0accbb8ffe..f0addee9d9 100644 --- a/pkg/apis/configuration/validation/policy_test.go +++ b/pkg/apis/configuration/validation/policy_test.go @@ -31,6 +31,24 @@ func TestValidatePolicyFails(t *testing.T) { if err == nil { t.Errorf("ValidatePolicy() returned no error for invalid input") } + + multiPolicy := &v1alpha1.Policy{ + Spec: v1alpha1.PolicySpec{ + AccessControl: &v1alpha1.AccessControl{ + Allow: []string{"127.0.0.1"}, + }, + RateLimit: &v1alpha1.RateLimit{ + Key: "${uri}", + ZoneSize: "10M", + Rate: "10r/s", + }, + }, + } + + err = ValidatePolicy(multiPolicy) + if err == nil { + t.Errorf("ValidatePolicy() returned no error for invalid input") + } } func TestValidateAccessControl(t *testing.T) { @@ -98,6 +116,111 @@ func TestValidateAccessControlFails(t *testing.T) { } } +func TestValidateRateLimit(t *testing.T) { + dryRun := true + noDelay := false + + tests := []struct { + rateLimit *v1alpha1.RateLimit + msg string + }{ + { + rateLimit: &v1alpha1.RateLimit{ + Rate: "10r/s", + ZoneSize: "10M", + Key: "${request_uri}", + }, + msg: "only required fields are set", + }, + { + rateLimit: &v1alpha1.RateLimit{ + Rate: "30r/m", + Key: "${request_uri}", + Delay: createPointerFromInt(5), + NoDelay: &noDelay, + Burst: createPointerFromInt(10), + ZoneSize: "10M", + DryRun: &dryRun, + LogLevel: "info", + RejectCode: createPointerFromInt(505), + }, + msg: "ratelimit all fields set", + }, + } + for _, test := range tests { + allErrs := validateRateLimit(test.rateLimit, field.NewPath("rateLimit")) + if len(allErrs) > 0 { + t.Errorf("validateRateLimit() returned errors %v for valid input for the case of %v", allErrs, test.msg) + } + } +} + +func createInvalidRateLimit(f func(r *v1alpha1.RateLimit)) *v1alpha1.RateLimit { + validRateLimit := &v1alpha1.RateLimit{ + Rate: "10r/s", + ZoneSize: "10M", + Key: "${request_uri}", + } + f(validRateLimit) + return validRateLimit +} + +func TestValidateRateLimitFails(t *testing.T) { + tests := []struct { + rateLimit *v1alpha1.RateLimit + msg string + }{ + { + rateLimit: createInvalidRateLimit(func(r *v1alpha1.RateLimit) { + r.Rate = "0r/s" + }), + msg: "invalid rateLimit rate", + }, + { + rateLimit: createInvalidRateLimit(func(r *v1alpha1.RateLimit) { + r.Key = "${fail}" + }), + msg: "invalid rateLimit key variable use", + }, + { + rateLimit: createInvalidRateLimit(func(r *v1alpha1.RateLimit) { + r.Delay = createPointerFromInt(0) + }), + msg: "invalid rateLimit delay", + }, + { + rateLimit: createInvalidRateLimit(func(r *v1alpha1.RateLimit) { + r.Burst = createPointerFromInt(0) + }), + msg: "invalid rateLimit burst", + }, + { + rateLimit: createInvalidRateLimit(func(r *v1alpha1.RateLimit) { + r.ZoneSize = "31k" + }), + msg: "invalid rateLimit zoneSize", + }, + { + rateLimit: createInvalidRateLimit(func(r *v1alpha1.RateLimit) { + r.RejectCode = createPointerFromInt(600) + }), + msg: "invalid rateLimit rejectCode", + }, + { + rateLimit: createInvalidRateLimit(func(r *v1alpha1.RateLimit) { + r.LogLevel = "invalid" + }), + msg: "invalid rateLimit logLevel", + }, + } + for _, test := range tests { + allErrs := validateRateLimit(test.rateLimit, field.NewPath("rateLimit")) + if len(allErrs) == 0 { + t.Errorf("validateRateLimit() returned no errors for invalid input for the case of %v", test.msg) + } + } +} + func TestValidateIPorCIDR(t *testing.T) { validInput := []string{ "192.168.1.1", @@ -127,3 +250,92 @@ func TestValidateIPorCIDR(t *testing.T) { } } } + +func TestValidateRate(t *testing.T) { + validInput := []string{ + "10r/s", + "100r/m", + "1r/s", + } + + for _, input := range validInput { + allErrs := validateRate(input, field.NewPath("rate")) + if len(allErrs) > 0 { + t.Errorf("validateRate(%q) returned errors %v for valid input", input, allErrs) + } + } + + invalidInput := []string{ + "10s", + "10r/", + "10r/ms", + "0r/s", + } + + for _, input := range invalidInput { + allErrs := validateRate(input, field.NewPath("rate")) + if len(allErrs) == 0 { + t.Errorf("validateRate(%q) returned no errors for invalid input", input) + } + } +} + +func TestValidatePositiveInt(t *testing.T) { + validInput := []int{1, 2} + + for _, input := range validInput { + allErrs := validatePositiveInt(input, field.NewPath("int")) + if len(allErrs) > 0 { + t.Errorf("validatePositiveInt(%q) returned errors %v for valid input", input, allErrs) + } + } + + invalidInput := []int{-1, 0} + + for _, input := range invalidInput { + allErrs := validatePositiveInt(input, field.NewPath("int")) + if len(allErrs) == 0 { + t.Errorf("validatePositiveInt(%q) returned no errors for invalid input", input) + } + } +} + +func TestValidateRateLimitZoneSize(t *testing.T) { + var validInput = []string{"32", "32k", "32K", "10m"} + + for _, test := range validInput { + allErrs := validateRateLimitZoneSize(test, field.NewPath("size")) + if len(allErrs) != 0 { + t.Errorf("validateRateLimitZoneSize(%q) returned an error for valid input", test) + } + } + + var invalidInput = []string{"", "31", "31k", "0", "0M"} + + for _, test := range invalidInput { + allErrs := validateRateLimitZoneSize(test, field.NewPath("size")) + if len(allErrs) == 0 { + t.Errorf("validateRateLimitZoneSize(%q) didn't return error for invalid input", test) + } + } +} + +func TestValidateRateLimitLogLevel(t *testing.T) { + var validInput = []string{"error", "info", "warn", "notice"} + + for _, test := range validInput { + allErrs := validateRateLimitLogLevel(test, field.NewPath("logLevel")) + if len(allErrs) != 0 { + t.Errorf("validateRateLimitLogLevel(%q) returned an error for valid input", test) + } + } + + var invalidInput = []string{"warn ", "info error", ""} + + for _, test := range invalidInput { + allErrs := validateRateLimitLogLevel(test, field.NewPath("logLevel")) + if len(allErrs) == 0 { + t.Errorf("validateRateLimitLogLevel(%q) didn't return error for invalid input", test) + } + } +} diff --git a/pkg/apis/configuration/validation/virtualserver.go b/pkg/apis/configuration/validation/virtualserver.go index 91af879033..507317f8a4 100644 --- a/pkg/apis/configuration/validation/virtualserver.go +++ b/pkg/apis/configuration/validation/virtualserver.go @@ -13,13 +13,6 @@ import ( "k8s.io/apimachinery/pkg/util/validation/field" ) -const ( - escapedStringsFmt = `([^"\\]|\\.)*` - escapedStringsErrMsg = `must have all '"' (double quotes) escaped and must not end with an unescaped '\' (backslash)` -) - -var escapedStringsFmtRegexp = regexp.MustCompile("^" + escapedStringsFmt + "$") - // ValidateVirtualServer validates a VirtualServer. func ValidateVirtualServer(virtualServer *v1.VirtualServer, isPlus bool) error { allErrs := validateVirtualServerSpec(&virtualServer.Spec, field.NewPath("spec"), isPlus, virtualServer.Namespace) @@ -202,25 +195,6 @@ func validateOffset(offset string, fieldPath *field.Path) field.ErrorList { return allErrs } -const sizeFmt = `\d+[kKmM]?` -const sizeErrMsg = "must consist of numeric characters followed by a valid size suffix. 'k|K|m|M" - -var sizeRegexp = regexp.MustCompile("^" + sizeFmt + "$") - -func validateSize(size string, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - - if size == "" { - return allErrs - } - - if !sizeRegexp.MatchString(size) { - msg := validation.RegexError(sizeErrMsg, sizeFmt, "16", "32k", "64M") - return append(allErrs, field.Invalid(fieldPath, size, msg)) - } - return allErrs -} - func validateBuffer(buff *v1.UpstreamBuffers, fieldPath *field.Path) field.ErrorList { allErrs := field.ErrorList{} @@ -863,88 +837,6 @@ func validateRedirectURL(redirectURL string, fieldPath *field.Path, validVars ma return allErrs } -func validateStringWithVariables(str string, fieldPath *field.Path, specialVars []string, validVars map[string]bool) field.ErrorList { - allErrs := field.ErrorList{} - - if strings.HasSuffix(str, "$") { - return append(allErrs, field.Invalid(fieldPath, str, "must not end with $")) - } - - for i, c := range str { - if c == '$' { - msg := "variables must be enclosed in curly braces, for example ${host}" - - if str[i+1] != '{' { - return append(allErrs, field.Invalid(fieldPath, str, msg)) - } - - if !strings.Contains(str[i+1:], "}") { - return append(allErrs, field.Invalid(fieldPath, str, msg)) - } - } - } - - nginxVars := captureVariables(str) - for _, nVar := range nginxVars { - special := false - for _, specialVar := range specialVars { - if strings.HasPrefix(nVar, specialVar) { - special = true - break - } - } - - if special { - allErrs = append(allErrs, validateSpecialVariable(nVar, fieldPath)...) - } else { - allErrs = append(allErrs, validateVariable(nVar, validVars, fieldPath)...) - } - } - - return allErrs -} - -func validateVariable(nVar string, validVars map[string]bool, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - - if !validVars[nVar] { - msg := fmt.Sprintf("'%v' contains an invalid NGINX variable. Accepted variables are: %v", nVar, mapToPrettyString(validVars)) - allErrs = append(allErrs, field.Invalid(fieldPath, nVar, msg)) - } - return allErrs -} - -func isValidSpecialVariableHeader(header string) []string { - // underscores in $http_ variable represent '-'. - errMsgs := validation.IsHTTPHeaderName(strings.Replace(header, "_", "-", -1)) - if len(errMsgs) >= 1 || strings.Contains(header, "-") { - return []string{"a valid HTTP header must consist of alphanumeric characters or '_'"} - } - return nil -} - -func validateSpecialVariable(nVar string, fieldPath *field.Path) field.ErrorList { - allErrs := field.ErrorList{} - value := strings.SplitN(nVar, "_", 2) - - switch value[0] { - case "arg": - for _, msg := range isArgumentName(value[1]) { - allErrs = append(allErrs, field.Invalid(fieldPath, nVar, msg)) - } - case "http": - for _, msg := range isValidSpecialVariableHeader(value[1]) { - allErrs = append(allErrs, field.Invalid(fieldPath, nVar, msg)) - } - case "cookie": - for _, msg := range isCookieName(value[1]) { - allErrs = append(allErrs, field.Invalid(fieldPath, nVar, msg)) - } - } - - return allErrs -} - func validateActionReturnCode(code int, fieldPath *field.Path) field.ErrorList { allErrs := field.ErrorList{} @@ -1005,16 +897,6 @@ func validateActionReturnType(returnType string, fieldPath *field.Path) field.Er return allErrs } -func mapToPrettyString(m map[string]bool) string { - var out []string - - for k := range m { - out = append(out, k) - } - - return strings.Join(out, ", ") -} - func validateRouteField(value string, fieldPath *field.Path) field.ErrorList { allErrs := field.ErrorList{} diff --git a/pkg/apis/configuration/validation/virtualserver_test.go b/pkg/apis/configuration/validation/virtualserver_test.go index 9c5d946cd1..5acb1b2ff6 100644 --- a/pkg/apis/configuration/validation/virtualserver_test.go +++ b/pkg/apis/configuration/validation/virtualserver_test.go @@ -2147,10 +2147,6 @@ func TestValidateUpstreamLBMethodFails(t *testing.T) { } } -func createPointerFromInt(n int) *int { - return &n -} - func TestValidatePositiveIntOrZeroFromPointer(t *testing.T) { tests := []struct { number *int @@ -2277,24 +2273,6 @@ func TestValidateBuffer(t *testing.T) { } } -func TestValidateSize(t *testing.T) { - var validInput = []string{"", "4k", "8K", "16m", "32M"} - for _, test := range validInput { - allErrs := validateSize(test, field.NewPath("size-field")) - if len(allErrs) != 0 { - t.Errorf("validateSize(%q) returned an error for valid input", test) - } - } - - var invalidInput = []string{"55mm", "2mG", "6kb", "-5k", "1L", "5G"} - for _, test := range invalidInput { - allErrs := validateSize(test, field.NewPath("size-field")) - if len(allErrs) == 0 { - t.Errorf("validateSize(%q) didn't return error for invalid input.", test) - } - } -} - func TestValidateTimeFails(t *testing.T) { time := "invalid" allErrs := validateTime(time, field.NewPath("time-field")) @@ -2734,53 +2712,6 @@ func TestValidateRedirectStatusCodeFails(t *testing.T) { } } -func TestValidateVariable(t *testing.T) { - var validVars = map[string]bool{ - "scheme": true, - "http_x_forwarded_proto": true, - "request_uri": true, - "host": true, - } - - tests := []struct { - nVar string - }{ - {"scheme"}, - {"http_x_forwarded_proto"}, - {"request_uri"}, - {"host"}, - } - for _, test := range tests { - allErrs := validateVariable(test.nVar, validVars, field.NewPath("url")) - if len(allErrs) != 0 { - t.Errorf("validateVariable(%v) returned errors %v for valid input", test.nVar, allErrs) - } - } -} - -func TestValidateVariableFails(t *testing.T) { - var validVars = map[string]bool{ - "host": true, - } - - tests := []struct { - nVar string - }{ - {""}, - {"hostinvalid.com"}, - {"$a"}, - {"host${host}"}, - {"host${host}}"}, - {"host$${host}"}, - } - for _, test := range tests { - allErrs := validateVariable(test.nVar, validVars, field.NewPath("url")) - if len(allErrs) == 0 { - t.Errorf("validateVariable(%v) returned no errors for invalid input", test.nVar) - } - } -} - func TestIsRegexOrExactMatch(t *testing.T) { tests := []struct { path string @@ -3281,69 +3212,6 @@ func TestValidateStringNoVariablesFails(t *testing.T) { } } -func TestValidateStringWithVariables(t *testing.T) { - testStrings := []string{ - "", - "${scheme}", - "${scheme}${host}", - "foo.bar", - } - validVars := map[string]bool{"scheme": true, "host": true} - - for _, test := range testStrings { - allErrs := validateStringWithVariables(test, field.NewPath("string"), nil, validVars) - if len(allErrs) != 0 { - t.Errorf("validateStringWithVariables(%v) returned errors for valid input: %v", test, allErrs) - } - } - - specialVars := []string{"arg", "http", "cookie"} - testStringsSpecial := []string{ - "${arg_username}", - "${http_header_name}", - "${cookie_cookie_name}", - } - - for _, test := range testStringsSpecial { - allErrs := validateStringWithVariables(test, field.NewPath("string"), specialVars, validVars) - if len(allErrs) != 0 { - t.Errorf("validateStringWithVariables(%v) returned errors for valid input: %v", test, allErrs) - } - } -} - -func TestValidateStringWithVariablesFail(t *testing.T) { - testStrings := []string{ - "$scheme}", - "${sch${eme}${host}", - "host$", - "${host", - "${invalid}", - } - validVars := map[string]bool{"scheme": true, "host": true} - - for _, test := range testStrings { - allErrs := validateStringWithVariables(test, field.NewPath("string"), nil, validVars) - if len(allErrs) == 0 { - t.Errorf("validateStringWithVariables(%v) returned no errors for invalid input", test) - } - } - - specialVars := []string{"arg", "http", "cookie"} - testStringsSpecial := []string{ - "${arg_username%}", - "${http_header-name}", - "${cookie_cookie?name}", - } - - for _, test := range testStringsSpecial { - allErrs := validateStringWithVariables(test, field.NewPath("string"), specialVars, validVars) - if len(allErrs) == 0 { - t.Errorf("validateStringWithVariables(%v) returned no errors for invalid input", test) - } - } -} - func TestValidateActionReturnCode(t *testing.T) { codes := []int{200, 201, 400, 404, 500, 502, 599} for _, c := range codes { @@ -3364,26 +3232,6 @@ func TestValidateActionReturnCodeFails(t *testing.T) { } } -func TestValidateSpecialVariable(t *testing.T) { - specialVars := []string{"arg_username", "arg_user_name", "http_header_name", "cookie_cookie_name"} - for _, v := range specialVars { - allErrs := validateSpecialVariable(v, field.NewPath("variable")) - if len(allErrs) != 0 { - t.Errorf("validateSpecialVariable(%v) returned errors for valid case: %v", v, allErrs) - } - } -} - -func TestValidateSpecialVariableFails(t *testing.T) { - specialVars := []string{"arg_invalid%", "http_header+invalid", "cookie_cookie_name?invalid"} - for _, v := range specialVars { - allErrs := validateSpecialVariable(v, field.NewPath("variable")) - if len(allErrs) == 0 { - t.Errorf("validateSpecialVariable(%v) returned no errors for invalid case", v) - } - } -} - func TestErrorPageHasRequiredFields(t *testing.T) { tests := []struct { errorPage v1.ErrorPage