Skip to content

Commit

Permalink
Support variables in action headers
Browse files Browse the repository at this point in the history
  • Loading branch information
pleshakov committed Sep 25, 2020
1 parent 46da41a commit 6cb6a80
Show file tree
Hide file tree
Showing 8 changed files with 685 additions and 179 deletions.
2 changes: 2 additions & 0 deletions cmd/nginx-ingress/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,7 @@ func main() {
controllerNamespace := os.Getenv("POD_NAMESPACE")

transportServerValidator := cr_validation.NewTransportServerValidator(*enableTLSPassthrough)
virtualServerValidator := cr_validation.NewVirtualServerValidator(*nginxPlus)

lbcInput := k8s.NewLoadBalancerControllerInput{
KubeClient: kubeClient,
Expand All @@ -605,6 +606,7 @@ func main() {
MetricsCollector: controllerCollector,
GlobalConfigurationValidator: globalConfigurationValidator,
TransportServerValidator: transportServerValidator,
VirtualServerValidator: virtualServerValidator,
SpireAgentAddress: *spireAgentAddress,
InternalRoutesEnabled: *enableInternalRoutes,
IsLatencyMetricsEnabled: *enableLatencyMetrics,
Expand Down
13 changes: 8 additions & 5 deletions internal/k8s/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ type LoadBalancerController struct {
metricsCollector collectors.ControllerCollector
globalConfigurationValidator *validation.GlobalConfigurationValidator
transportServerValidator *validation.TransportServerValidator
virtualServerValidator *validation.VirtualServerValidator
spiffeController *spiffeController
internalRoutesEnabled bool
syncLock sync.Mutex
Expand Down Expand Up @@ -185,6 +186,7 @@ type NewLoadBalancerControllerInput struct {
MetricsCollector collectors.ControllerCollector
GlobalConfigurationValidator *validation.GlobalConfigurationValidator
TransportServerValidator *validation.TransportServerValidator
VirtualServerValidator *validation.VirtualServerValidator
SpireAgentAddress string
InternalRoutesEnabled bool
IsLatencyMetricsEnabled bool
Expand Down Expand Up @@ -213,6 +215,7 @@ func NewLoadBalancerController(input NewLoadBalancerControllerInput) *LoadBalanc
metricsCollector: input.MetricsCollector,
globalConfigurationValidator: input.GlobalConfigurationValidator,
transportServerValidator: input.TransportServerValidator,
virtualServerValidator: input.VirtualServerValidator,
internalRoutesEnabled: input.InternalRoutesEnabled,
isLatencyMetricsEnabled: input.IsLatencyMetricsEnabled,
}
Expand Down Expand Up @@ -1111,7 +1114,7 @@ func (lbc *LoadBalancerController) syncVirtualServer(task task) {
glog.V(2).Infof("Adding or Updating VirtualServer: %v\n", key)
vs := obj.(*conf_v1.VirtualServer)

validationErr := validation.ValidateVirtualServer(vs, lbc.isNginxPlus)
validationErr := lbc.virtualServerValidator.ValidateVirtualServer(vs)
if validationErr != nil {
err := lbc.configurator.DeleteVirtualServer(key)
if err != nil {
Expand Down Expand Up @@ -1275,7 +1278,7 @@ func (lbc *LoadBalancerController) syncVirtualServerRoute(task task) {

vsr := obj.(*conf_v1.VirtualServerRoute)

validationErr := validation.ValidateVirtualServerRoute(vsr, lbc.isNginxPlus)
validationErr := lbc.virtualServerValidator.ValidateVirtualServerRoute(vsr)
if validationErr != nil {
reason := "Rejected"
msg := fmt.Sprintf("VirtualServerRoute %s is invalid and was rejected: %v", key, validationErr)
Expand Down Expand Up @@ -2219,7 +2222,7 @@ func (lbc *LoadBalancerController) getVirtualServers() []*conf_v1.VirtualServer
continue
}

err := validation.ValidateVirtualServer(vs, lbc.isNginxPlus)
err := lbc.virtualServerValidator.ValidateVirtualServer(vs)
if err != nil {
glog.V(3).Infof("Skipping invalid VirtualServer %s/%s: %v", vs.Namespace, vs.Name, err)
continue
Expand All @@ -2242,7 +2245,7 @@ func (lbc *LoadBalancerController) getVirtualServerRoutes() []*conf_v1.VirtualSe
continue
}

err := validation.ValidateVirtualServerRoute(vsr, lbc.isNginxPlus)
err := lbc.virtualServerValidator.ValidateVirtualServerRoute(vsr)
if err != nil {
glog.V(3).Infof("Skipping invalid VirtualServerRoute %s/%s: %v", vsr.Namespace, vsr.Name, err)
continue
Expand Down Expand Up @@ -2703,7 +2706,7 @@ func (lbc *LoadBalancerController) createVirtualServer(virtualServer *conf_v1.Vi
continue
}

err = validation.ValidateVirtualServerRouteForVirtualServer(vsr, virtualServer.Spec.Host, r.Path, lbc.isNginxPlus)
err = lbc.virtualServerValidator.ValidateVirtualServerRouteForVirtualServer(vsr, virtualServer.Spec.Host, r.Path)
if err != nil {
glog.Warningf("VirtualServer %s/%s references invalid VirtualServerRoute %s: %v", virtualServer.Name, virtualServer.Namespace, vsrKey, err)
virtualServerRouteErrors = append(virtualServerRouteErrors, newVirtualServerRouteErrorFromVSR(vsr, err))
Expand Down
71 changes: 54 additions & 17 deletions pkg/apis/configuration/validation/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,38 +26,75 @@ func validateVariable(nVar string, validVars map[string]bool, fieldPath *field.P
return allErrs
}

func isValidSpecialVariableHeader(header string) []string {
// underscores in $http_ variable represent '-'.
errMsgs := validation.IsHTTPHeaderName(strings.Replace(header, "_", "-", -1))
if len(errMsgs) >= 1 || strings.Contains(header, "-") {
return []string{"a valid HTTP header must consist of alphanumeric characters or '_'"}
// isValidSpecialHeaderLikeVariable validates special variables $http_, $jwt_header_, $jwt_claim_
func isValidSpecialHeaderLikeVariable(value string) []string {
// underscores in a header-like variable represent '-'.
errMsgs := validation.IsHTTPHeaderName(strings.Replace(value, "_", "-", -1))
if len(errMsgs) >= 1 || strings.Contains(value, "-") {
return []string{"a valid special variable must consists of alphanumeric characters or '_'"}
}
return nil
}

func validateSpecialVariable(nVar string, fieldPath *field.Path) field.ErrorList {
func parseSpecialVariable(nVar string, fieldPath *field.Path) (name string, value string, allErrs field.ErrorList) {
// parse NGINX Plus variables
if strings.HasPrefix(nVar, "jwt_header") || strings.HasPrefix(nVar, "jwt_claim") {
parts := strings.SplitN(nVar, "_", 3)
if len(parts) != 3 {
allErrs = append(allErrs, field.Invalid(fieldPath, nVar, "is invalid variable"))
return name, value, allErrs
}

// ex: jwt_header_name_one -> jwt_header, name_one
return strings.Join(parts[:2], "_"), parts[2], allErrs
}

// parse common NGINX and NGINX Plus variables
parts := strings.SplitN(nVar, "_", 2)
if len(parts) != 2 {
allErrs = append(allErrs, field.Invalid(fieldPath, nVar, "is invalid variable"))
return name, value, allErrs
}

// ex: http_name_one -> http, name_one
return parts[0], parts[1], allErrs
}

func validateSpecialVariable(nVar string, fieldPath *field.Path, isPlus bool) field.ErrorList {
allErrs := field.ErrorList{}
value := strings.SplitN(nVar, "_", 2)

switch value[0] {
case "arg":
for _, msg := range isArgumentName(value[1]) {
name, value, allErrs := parseSpecialVariable(nVar, fieldPath)
if len(allErrs) > 0 {
return allErrs
}

addErrors := func(errors []string) {
for _, msg := range errors {
allErrs = append(allErrs, field.Invalid(fieldPath, nVar, msg))
}
}

switch name {
case "arg":
addErrors(isArgumentName(value))
case "http":
for _, msg := range isValidSpecialVariableHeader(value[1]) {
allErrs = append(allErrs, field.Invalid(fieldPath, nVar, msg))
}
addErrors(isValidSpecialHeaderLikeVariable(value))
case "cookie":
for _, msg := range isCookieName(value[1]) {
allErrs = append(allErrs, field.Invalid(fieldPath, nVar, msg))
addErrors(isCookieName(value))
case "jwt_header", "jwt_claim":
if !isPlus {
allErrs = append(allErrs, field.Forbidden(fieldPath, "is only supported in NGINX Plus"))
} else {
addErrors(isValidSpecialHeaderLikeVariable(value))
}
default:
allErrs = append(allErrs, field.Invalid(fieldPath, nVar, "unknown special variable"))
}

return allErrs
}

func validateStringWithVariables(str string, fieldPath *field.Path, specialVars []string, validVars map[string]bool) field.ErrorList {
func validateStringWithVariables(str string, fieldPath *field.Path, specialVars []string, validVars map[string]bool, isPlus bool) field.ErrorList {
allErrs := field.ErrorList{}

if strings.HasSuffix(str, "$") {
Expand Down Expand Up @@ -89,7 +126,7 @@ func validateStringWithVariables(str string, fieldPath *field.Path, specialVars
}

if special {
allErrs = append(allErrs, validateSpecialVariable(nVar, fieldPath)...)
allErrs = append(allErrs, validateSpecialVariable(nVar, fieldPath, isPlus)...)
} else {
allErrs = append(allErrs, validateVariable(nVar, validVars, fieldPath)...)
}
Expand Down
147 changes: 139 additions & 8 deletions pkg/apis/configuration/validation/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,27 +52,156 @@ func TestValidateVariableFails(t *testing.T) {
}
}

func TestParseSpecialVariable(t *testing.T) {
tests := []struct {
specialVar string
expectedName string
expectedValue string
}{
{
specialVar: "arg_username",
expectedName: "arg",
expectedValue: "username",
},
{
specialVar: "arg_user_name",
expectedName: "arg",
expectedValue: "user_name",
},
{
specialVar: "jwt_header_username",
expectedName: "jwt_header",
expectedValue: "username",
},
{
specialVar: "jwt_header_user_name",
expectedName: "jwt_header",
expectedValue: "user_name",
},
{
specialVar: "jwt_claim_username",
expectedName: "jwt_claim",
expectedValue: "username",
},
{
specialVar: "jwt_claim_user_name",
expectedName: "jwt_claim",
expectedValue: "user_name",
},
}

for _, test := range tests {
name, value, allErrs := parseSpecialVariable(test.specialVar, field.NewPath("variable"))
if name != test.expectedName {
t.Errorf("parseSpecialVariable(%v) returned name %v but expected %v", test.specialVar, name, test.expectedName)
}
if value != test.expectedValue {
t.Errorf("parseSpecialVariable(%v) returned value %v but expected %v", test.specialVar, value, test.expectedValue)
}
if len(allErrs) != 0 {
t.Errorf("parseSpecialVariable(%v) returned errors for valid case: %v", test.specialVar, allErrs)
}
}
}

func TestParseSpecialVariableFails(t *testing.T) {
specialVars := []string{
"arg",
"jwt_header",
"jwt_claim",
}

for _, v := range specialVars {
_, _, allErrs := parseSpecialVariable(v, field.NewPath("variable"))
if len(allErrs) == 0 {
t.Errorf("parseSpecialVariable(%v) returned no errors for invalid case", v)
}
}
}

func TestValidateSpecialVariable(t *testing.T) {
specialVars := []string{"arg_username", "arg_user_name", "http_header_name", "cookie_cookie_name"}
specialVars := []string{
"arg_username",
"arg_user_name",
"http_header_name",
"cookie_cookie_name",
}

isPlus := false

for _, v := range specialVars {
allErrs := validateSpecialVariable(v, field.NewPath("variable"))
allErrs := validateSpecialVariable(v, field.NewPath("variable"), isPlus)
if len(allErrs) != 0 {
t.Errorf("validateSpecialVariable(%v) returned errors for valid case: %v", v, allErrs)
}
}
}

func TestValidateSpecialVariableForPlus(t *testing.T) {
specialVars := []string{
"arg_username",
"arg_user_name",
"http_header_name",
"cookie_cookie_name",
"jwt_header_alg",
"jwt_claim_user",
}

isPlus := true

for _, v := range specialVars {
allErrs := validateSpecialVariable(v, field.NewPath("variable"), isPlus)
if len(allErrs) != 0 {
t.Errorf("validateSpecialVariable(%v) returned errors for valid case: %v", v, allErrs)
}
}
}

func TestValidateSpecialVariableFails(t *testing.T) {
specialVars := []string{"arg_invalid%", "http_header+invalid", "cookie_cookie_name?invalid"}
specialVars := []string{
"arg",
"arg_invalid%",
"http_header+invalid",
"cookie_cookie_name?invalid",
"jwt_header_alg",
"jwt_claim_user",
"some_var",
}

isPlus := false

for _, v := range specialVars {
allErrs := validateSpecialVariable(v, field.NewPath("variable"), isPlus)
if len(allErrs) == 0 {
t.Errorf("validateSpecialVariable(%v) returned no errors for invalid case", v)
}
}
}

func TestValidateSpecialVariableForPlusFails(t *testing.T) {
specialVars := []string{
"arg",
"arg_invalid%",
"http_header+invalid",
"cookie_cookie_name?invalid",
"jwt_header_+invalid",
"wt_claim_invalid?",
"some_var",
}

isPlus := true

for _, v := range specialVars {
allErrs := validateSpecialVariable(v, field.NewPath("variable"))
allErrs := validateSpecialVariable(v, field.NewPath("variable"), isPlus)
if len(allErrs) == 0 {
t.Errorf("validateSpecialVariable(%v) returned no errors for invalid case", v)
}
}
}

func TestValidateStringWithVariables(t *testing.T) {
isPlus := false

testStrings := []string{
"",
"${scheme}",
Expand All @@ -82,7 +211,7 @@ func TestValidateStringWithVariables(t *testing.T) {
validVars := map[string]bool{"scheme": true, "host": true}

for _, test := range testStrings {
allErrs := validateStringWithVariables(test, field.NewPath("string"), nil, validVars)
allErrs := validateStringWithVariables(test, field.NewPath("string"), nil, validVars, isPlus)
if len(allErrs) != 0 {
t.Errorf("validateStringWithVariables(%v) returned errors for valid input: %v", test, allErrs)
}
Expand All @@ -96,14 +225,16 @@ func TestValidateStringWithVariables(t *testing.T) {
}

for _, test := range testStringsSpecial {
allErrs := validateStringWithVariables(test, field.NewPath("string"), specialVars, validVars)
allErrs := validateStringWithVariables(test, field.NewPath("string"), specialVars, validVars, isPlus)
if len(allErrs) != 0 {
t.Errorf("validateStringWithVariables(%v) returned errors for valid input: %v", test, allErrs)
}
}
}

func TestValidateStringWithVariablesFail(t *testing.T) {
isPlus := false

testStrings := []string{
"$scheme}",
"${sch${eme}${host}",
Expand All @@ -114,7 +245,7 @@ func TestValidateStringWithVariablesFail(t *testing.T) {
validVars := map[string]bool{"scheme": true, "host": true}

for _, test := range testStrings {
allErrs := validateStringWithVariables(test, field.NewPath("string"), nil, validVars)
allErrs := validateStringWithVariables(test, field.NewPath("string"), nil, validVars, isPlus)
if len(allErrs) == 0 {
t.Errorf("validateStringWithVariables(%v) returned no errors for invalid input", test)
}
Expand All @@ -128,7 +259,7 @@ func TestValidateStringWithVariablesFail(t *testing.T) {
}

for _, test := range testStringsSpecial {
allErrs := validateStringWithVariables(test, field.NewPath("string"), specialVars, validVars)
allErrs := validateStringWithVariables(test, field.NewPath("string"), specialVars, validVars, isPlus)
if len(allErrs) == 0 {
t.Errorf("validateStringWithVariables(%v) returned no errors for invalid input", test)
}
Expand Down
Loading

0 comments on commit 6cb6a80

Please sign in to comment.