Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support variables in action proxy headers #1158

Merged
merged 1 commit into from
Sep 25, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cmd/nginx-ingress/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,7 @@ func main() {
controllerNamespace := os.Getenv("POD_NAMESPACE")

transportServerValidator := cr_validation.NewTransportServerValidator(*enableTLSPassthrough)
virtualServerValidator := cr_validation.NewVirtualServerValidator(*nginxPlus)

lbcInput := k8s.NewLoadBalancerControllerInput{
KubeClient: kubeClient,
Expand All @@ -605,6 +606,7 @@ func main() {
MetricsCollector: controllerCollector,
GlobalConfigurationValidator: globalConfigurationValidator,
TransportServerValidator: transportServerValidator,
VirtualServerValidator: virtualServerValidator,
SpireAgentAddress: *spireAgentAddress,
InternalRoutesEnabled: *enableInternalRoutes,
IsLatencyMetricsEnabled: *enableLatencyMetrics,
Expand Down
13 changes: 8 additions & 5 deletions internal/k8s/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ type LoadBalancerController struct {
metricsCollector collectors.ControllerCollector
globalConfigurationValidator *validation.GlobalConfigurationValidator
transportServerValidator *validation.TransportServerValidator
virtualServerValidator *validation.VirtualServerValidator
spiffeController *spiffeController
internalRoutesEnabled bool
syncLock sync.Mutex
Expand Down Expand Up @@ -185,6 +186,7 @@ type NewLoadBalancerControllerInput struct {
MetricsCollector collectors.ControllerCollector
GlobalConfigurationValidator *validation.GlobalConfigurationValidator
TransportServerValidator *validation.TransportServerValidator
VirtualServerValidator *validation.VirtualServerValidator
SpireAgentAddress string
InternalRoutesEnabled bool
IsLatencyMetricsEnabled bool
Expand Down Expand Up @@ -213,6 +215,7 @@ func NewLoadBalancerController(input NewLoadBalancerControllerInput) *LoadBalanc
metricsCollector: input.MetricsCollector,
globalConfigurationValidator: input.GlobalConfigurationValidator,
transportServerValidator: input.TransportServerValidator,
virtualServerValidator: input.VirtualServerValidator,
internalRoutesEnabled: input.InternalRoutesEnabled,
isLatencyMetricsEnabled: input.IsLatencyMetricsEnabled,
}
Expand Down Expand Up @@ -1111,7 +1114,7 @@ func (lbc *LoadBalancerController) syncVirtualServer(task task) {
glog.V(2).Infof("Adding or Updating VirtualServer: %v\n", key)
vs := obj.(*conf_v1.VirtualServer)

validationErr := validation.ValidateVirtualServer(vs, lbc.isNginxPlus)
validationErr := lbc.virtualServerValidator.ValidateVirtualServer(vs)
if validationErr != nil {
err := lbc.configurator.DeleteVirtualServer(key)
if err != nil {
Expand Down Expand Up @@ -1275,7 +1278,7 @@ func (lbc *LoadBalancerController) syncVirtualServerRoute(task task) {

vsr := obj.(*conf_v1.VirtualServerRoute)

validationErr := validation.ValidateVirtualServerRoute(vsr, lbc.isNginxPlus)
validationErr := lbc.virtualServerValidator.ValidateVirtualServerRoute(vsr)
if validationErr != nil {
reason := "Rejected"
msg := fmt.Sprintf("VirtualServerRoute %s is invalid and was rejected: %v", key, validationErr)
Expand Down Expand Up @@ -2219,7 +2222,7 @@ func (lbc *LoadBalancerController) getVirtualServers() []*conf_v1.VirtualServer
continue
}

err := validation.ValidateVirtualServer(vs, lbc.isNginxPlus)
err := lbc.virtualServerValidator.ValidateVirtualServer(vs)
if err != nil {
glog.V(3).Infof("Skipping invalid VirtualServer %s/%s: %v", vs.Namespace, vs.Name, err)
continue
Expand All @@ -2242,7 +2245,7 @@ func (lbc *LoadBalancerController) getVirtualServerRoutes() []*conf_v1.VirtualSe
continue
}

err := validation.ValidateVirtualServerRoute(vsr, lbc.isNginxPlus)
err := lbc.virtualServerValidator.ValidateVirtualServerRoute(vsr)
if err != nil {
glog.V(3).Infof("Skipping invalid VirtualServerRoute %s/%s: %v", vsr.Namespace, vsr.Name, err)
continue
Expand Down Expand Up @@ -2703,7 +2706,7 @@ func (lbc *LoadBalancerController) createVirtualServer(virtualServer *conf_v1.Vi
continue
}

err = validation.ValidateVirtualServerRouteForVirtualServer(vsr, virtualServer.Spec.Host, r.Path, lbc.isNginxPlus)
err = lbc.virtualServerValidator.ValidateVirtualServerRouteForVirtualServer(vsr, virtualServer.Spec.Host, r.Path)
if err != nil {
glog.Warningf("VirtualServer %s/%s references invalid VirtualServerRoute %s: %v", virtualServer.Name, virtualServer.Namespace, vsrKey, err)
virtualServerRouteErrors = append(virtualServerRouteErrors, newVirtualServerRouteErrorFromVSR(vsr, err))
Expand Down
71 changes: 54 additions & 17 deletions pkg/apis/configuration/validation/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,38 +26,75 @@ func validateVariable(nVar string, validVars map[string]bool, fieldPath *field.P
return allErrs
}

func isValidSpecialVariableHeader(header string) []string {
// underscores in $http_ variable represent '-'.
errMsgs := validation.IsHTTPHeaderName(strings.Replace(header, "_", "-", -1))
if len(errMsgs) >= 1 || strings.Contains(header, "-") {
return []string{"a valid HTTP header must consist of alphanumeric characters or '_'"}
// isValidSpecialHeaderLikeVariable validates special variables $http_, $jwt_header_, $jwt_claim_
func isValidSpecialHeaderLikeVariable(value string) []string {
// underscores in a header-like variable represent '-'.
errMsgs := validation.IsHTTPHeaderName(strings.Replace(value, "_", "-", -1))
if len(errMsgs) >= 1 || strings.Contains(value, "-") {
return []string{"a valid special variable must consists of alphanumeric characters or '_'"}
}
return nil
}

func validateSpecialVariable(nVar string, fieldPath *field.Path) field.ErrorList {
func parseSpecialVariable(nVar string, fieldPath *field.Path) (name string, value string, allErrs field.ErrorList) {
// parse NGINX Plus variables
if strings.HasPrefix(nVar, "jwt_header") || strings.HasPrefix(nVar, "jwt_claim") {
parts := strings.SplitN(nVar, "_", 3)
if len(parts) != 3 {
allErrs = append(allErrs, field.Invalid(fieldPath, nVar, "is invalid variable"))
return name, value, allErrs
}

// ex: jwt_header_name_one -> jwt_header, name_one
return strings.Join(parts[:2], "_"), parts[2], allErrs
}

// parse common NGINX and NGINX Plus variables
parts := strings.SplitN(nVar, "_", 2)
if len(parts) != 2 {
allErrs = append(allErrs, field.Invalid(fieldPath, nVar, "is invalid variable"))
return name, value, allErrs
}

// ex: http_name_one -> http, name_one
return parts[0], parts[1], allErrs
}

func validateSpecialVariable(nVar string, fieldPath *field.Path, isPlus bool) field.ErrorList {
allErrs := field.ErrorList{}
value := strings.SplitN(nVar, "_", 2)

switch value[0] {
case "arg":
for _, msg := range isArgumentName(value[1]) {
name, value, allErrs := parseSpecialVariable(nVar, fieldPath)
if len(allErrs) > 0 {
return allErrs
}

addErrors := func(errors []string) {
for _, msg := range errors {
allErrs = append(allErrs, field.Invalid(fieldPath, nVar, msg))
}
}

switch name {
case "arg":
addErrors(isArgumentName(value))
case "http":
for _, msg := range isValidSpecialVariableHeader(value[1]) {
allErrs = append(allErrs, field.Invalid(fieldPath, nVar, msg))
}
addErrors(isValidSpecialHeaderLikeVariable(value))
case "cookie":
for _, msg := range isCookieName(value[1]) {
allErrs = append(allErrs, field.Invalid(fieldPath, nVar, msg))
addErrors(isCookieName(value))
case "jwt_header", "jwt_claim":
if !isPlus {
allErrs = append(allErrs, field.Forbidden(fieldPath, "is only supported in NGINX Plus"))
} else {
addErrors(isValidSpecialHeaderLikeVariable(value))
}
default:
allErrs = append(allErrs, field.Invalid(fieldPath, nVar, "unknown special variable"))
}

return allErrs
}

func validateStringWithVariables(str string, fieldPath *field.Path, specialVars []string, validVars map[string]bool) field.ErrorList {
func validateStringWithVariables(str string, fieldPath *field.Path, specialVars []string, validVars map[string]bool, isPlus bool) field.ErrorList {
allErrs := field.ErrorList{}

if strings.HasSuffix(str, "$") {
Expand Down Expand Up @@ -89,7 +126,7 @@ func validateStringWithVariables(str string, fieldPath *field.Path, specialVars
}

if special {
allErrs = append(allErrs, validateSpecialVariable(nVar, fieldPath)...)
allErrs = append(allErrs, validateSpecialVariable(nVar, fieldPath, isPlus)...)
} else {
allErrs = append(allErrs, validateVariable(nVar, validVars, fieldPath)...)
}
Expand Down
147 changes: 139 additions & 8 deletions pkg/apis/configuration/validation/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,27 +52,156 @@ func TestValidateVariableFails(t *testing.T) {
}
}

func TestParseSpecialVariable(t *testing.T) {
tests := []struct {
specialVar string
expectedName string
expectedValue string
}{
{
specialVar: "arg_username",
expectedName: "arg",
expectedValue: "username",
},
{
specialVar: "arg_user_name",
expectedName: "arg",
expectedValue: "user_name",
},
{
specialVar: "jwt_header_username",
expectedName: "jwt_header",
expectedValue: "username",
},
{
specialVar: "jwt_header_user_name",
expectedName: "jwt_header",
expectedValue: "user_name",
},
{
specialVar: "jwt_claim_username",
expectedName: "jwt_claim",
expectedValue: "username",
},
{
specialVar: "jwt_claim_user_name",
expectedName: "jwt_claim",
expectedValue: "user_name",
},
}

for _, test := range tests {
name, value, allErrs := parseSpecialVariable(test.specialVar, field.NewPath("variable"))
if name != test.expectedName {
t.Errorf("parseSpecialVariable(%v) returned name %v but expected %v", test.specialVar, name, test.expectedName)
}
if value != test.expectedValue {
t.Errorf("parseSpecialVariable(%v) returned value %v but expected %v", test.specialVar, value, test.expectedValue)
}
if len(allErrs) != 0 {
t.Errorf("parseSpecialVariable(%v) returned errors for valid case: %v", test.specialVar, allErrs)
}
}
}

func TestParseSpecialVariableFails(t *testing.T) {
specialVars := []string{
"arg",
"jwt_header",
"jwt_claim",
}

for _, v := range specialVars {
_, _, allErrs := parseSpecialVariable(v, field.NewPath("variable"))
if len(allErrs) == 0 {
t.Errorf("parseSpecialVariable(%v) returned no errors for invalid case", v)
}
}
}

func TestValidateSpecialVariable(t *testing.T) {
specialVars := []string{"arg_username", "arg_user_name", "http_header_name", "cookie_cookie_name"}
specialVars := []string{
"arg_username",
"arg_user_name",
"http_header_name",
"cookie_cookie_name",
}

isPlus := false

for _, v := range specialVars {
allErrs := validateSpecialVariable(v, field.NewPath("variable"))
allErrs := validateSpecialVariable(v, field.NewPath("variable"), isPlus)
if len(allErrs) != 0 {
t.Errorf("validateSpecialVariable(%v) returned errors for valid case: %v", v, allErrs)
}
}
}

func TestValidateSpecialVariableForPlus(t *testing.T) {
specialVars := []string{
"arg_username",
"arg_user_name",
"http_header_name",
"cookie_cookie_name",
"jwt_header_alg",
"jwt_claim_user",
}

isPlus := true

for _, v := range specialVars {
allErrs := validateSpecialVariable(v, field.NewPath("variable"), isPlus)
if len(allErrs) != 0 {
t.Errorf("validateSpecialVariable(%v) returned errors for valid case: %v", v, allErrs)
}
}
}

func TestValidateSpecialVariableFails(t *testing.T) {
specialVars := []string{"arg_invalid%", "http_header+invalid", "cookie_cookie_name?invalid"}
specialVars := []string{
"arg",
"arg_invalid%",
"http_header+invalid",
"cookie_cookie_name?invalid",
"jwt_header_alg",
"jwt_claim_user",
"some_var",
}

isPlus := false

for _, v := range specialVars {
allErrs := validateSpecialVariable(v, field.NewPath("variable"), isPlus)
if len(allErrs) == 0 {
t.Errorf("validateSpecialVariable(%v) returned no errors for invalid case", v)
}
}
}

func TestValidateSpecialVariableForPlusFails(t *testing.T) {
specialVars := []string{
"arg",
"arg_invalid%",
"http_header+invalid",
"cookie_cookie_name?invalid",
"jwt_header_+invalid",
"wt_claim_invalid?",
"some_var",
}

isPlus := true

for _, v := range specialVars {
allErrs := validateSpecialVariable(v, field.NewPath("variable"))
allErrs := validateSpecialVariable(v, field.NewPath("variable"), isPlus)
if len(allErrs) == 0 {
t.Errorf("validateSpecialVariable(%v) returned no errors for invalid case", v)
}
}
}

func TestValidateStringWithVariables(t *testing.T) {
isPlus := false

testStrings := []string{
"",
"${scheme}",
Expand All @@ -82,7 +211,7 @@ func TestValidateStringWithVariables(t *testing.T) {
validVars := map[string]bool{"scheme": true, "host": true}

for _, test := range testStrings {
allErrs := validateStringWithVariables(test, field.NewPath("string"), nil, validVars)
allErrs := validateStringWithVariables(test, field.NewPath("string"), nil, validVars, isPlus)
if len(allErrs) != 0 {
t.Errorf("validateStringWithVariables(%v) returned errors for valid input: %v", test, allErrs)
}
Expand All @@ -96,14 +225,16 @@ func TestValidateStringWithVariables(t *testing.T) {
}

for _, test := range testStringsSpecial {
allErrs := validateStringWithVariables(test, field.NewPath("string"), specialVars, validVars)
allErrs := validateStringWithVariables(test, field.NewPath("string"), specialVars, validVars, isPlus)
if len(allErrs) != 0 {
t.Errorf("validateStringWithVariables(%v) returned errors for valid input: %v", test, allErrs)
}
}
}

func TestValidateStringWithVariablesFail(t *testing.T) {
isPlus := false

testStrings := []string{
"$scheme}",
"${sch${eme}${host}",
Expand All @@ -114,7 +245,7 @@ func TestValidateStringWithVariablesFail(t *testing.T) {
validVars := map[string]bool{"scheme": true, "host": true}

for _, test := range testStrings {
allErrs := validateStringWithVariables(test, field.NewPath("string"), nil, validVars)
allErrs := validateStringWithVariables(test, field.NewPath("string"), nil, validVars, isPlus)
if len(allErrs) == 0 {
t.Errorf("validateStringWithVariables(%v) returned no errors for invalid input", test)
}
Expand All @@ -128,7 +259,7 @@ func TestValidateStringWithVariablesFail(t *testing.T) {
}

for _, test := range testStringsSpecial {
allErrs := validateStringWithVariables(test, field.NewPath("string"), specialVars, validVars)
allErrs := validateStringWithVariables(test, field.NewPath("string"), specialVars, validVars, isPlus)
if len(allErrs) == 0 {
t.Errorf("validateStringWithVariables(%v) returned no errors for invalid input", test)
}
Expand Down
Loading