Skip to content

Commit

Permalink
Add default web handling of method validation errors
Browse files Browse the repository at this point in the history
Closes gh-30644
  • Loading branch information
rstoyanchev committed Jul 3, 2023
1 parent a481c76 commit 7a79da5
Show file tree
Hide file tree
Showing 10 changed files with 338 additions and 62 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.web.method.annotation;

import java.lang.reflect.Method;
import java.util.List;
import java.util.Locale;

import org.springframework.context.MessageSource;
import org.springframework.http.HttpStatus;
import org.springframework.validation.beanvalidation.MethodValidationResult;
import org.springframework.validation.beanvalidation.ParameterValidationResult;
import org.springframework.web.server.ResponseStatusException;
import org.springframework.web.util.BindErrorUtils;

/**
* {@link ResponseStatusException} that is also {@link MethodValidationResult}.
* Raised by {@link HandlerMethodValidator} in case of method validation errors
* on a web controller method.
*
* <p>The {@link #getStatusCode()} is 400 for input validation errors, and 500
* for validation errors on a return value.
*
* @author Rossen Stoyanchev
* @since 6.1
*/
@SuppressWarnings("serial")
public class HandlerMethodValidationException extends ResponseStatusException implements MethodValidationResult {

private final MethodValidationResult validationResult;


public HandlerMethodValidationException(MethodValidationResult validationResult) {
super(initHttpStatus(validationResult), "Validation failure", null, null, null);
this.validationResult = validationResult;
}

private static HttpStatus initHttpStatus(MethodValidationResult validationResult) {
return (!validationResult.isForReturnValue() ? HttpStatus.BAD_REQUEST : HttpStatus.INTERNAL_SERVER_ERROR);
}


@Override
public Object[] getDetailMessageArguments(MessageSource messageSource, Locale locale) {
return new Object[] { BindErrorUtils.resolveAndJoin(getAllErrors(), messageSource, locale) };
}

@Override
public Object[] getDetailMessageArguments() {
return new Object[] { BindErrorUtils.resolveAndJoin(getAllErrors()) };
}

@Override
public Object getTarget() {
return this.validationResult.getTarget();
}

@Override
public Method getMethod() {
return this.validationResult.getMethod();
}

@Override
public boolean isForReturnValue() {
return this.validationResult.isForReturnValue();
}

@Override
public List<ParameterValidationResult> getAllValidationResults() {
return this.validationResult.getAllValidationResults();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import org.springframework.validation.BindingResult;
import org.springframework.validation.MessageCodesResolver;
import org.springframework.validation.beanvalidation.MethodValidationAdapter;
import org.springframework.validation.beanvalidation.MethodValidationException;
import org.springframework.validation.beanvalidation.MethodValidationResult;
import org.springframework.validation.beanvalidation.MethodValidator;
import org.springframework.validation.beanvalidation.ParameterErrors;
Expand Down Expand Up @@ -91,7 +90,7 @@ public void applyArgumentValidation(
}
}

throw new MethodValidationException(result);
throw new HandlerMethodValidationException(result);
}

@Override
Expand All @@ -109,7 +108,7 @@ public void applyReturnValueValidation(

MethodValidationResult result = validateReturnValue(target, method, returnType, returnValue, groups);
if (result.hasErrors()) {
throw new MethodValidationException(result);
throw new HandlerMethodValidationException(result);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,11 @@
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.BiFunction;

import org.junit.jupiter.api.Test;

import org.springframework.beans.testfixture.beans.TestBean;
import org.springframework.context.MessageSource;
import org.springframework.context.MessageSourceResolvable;
import org.springframework.context.support.StaticMessageSource;
import org.springframework.core.MethodParameter;
import org.springframework.http.HttpHeaders;
Expand All @@ -36,9 +34,9 @@
import org.springframework.http.ProblemDetail;
import org.springframework.lang.Nullable;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.validation.BindException;
import org.springframework.validation.BeanPropertyBindingResult;
import org.springframework.validation.BindingResult;
import org.springframework.validation.ObjectError;
import org.springframework.validation.beanvalidation.MethodValidationResult;
import org.springframework.web.bind.MethodArgumentNotValidException;
import org.springframework.web.bind.MissingMatrixVariableException;
import org.springframework.web.bind.MissingPathVariableException;
Expand All @@ -48,6 +46,7 @@
import org.springframework.web.bind.UnsatisfiedServletRequestParameterException;
import org.springframework.web.bind.support.WebExchangeBindException;
import org.springframework.web.context.request.async.AsyncRequestTimeoutException;
import org.springframework.web.method.annotation.HandlerMethodValidationException;
import org.springframework.web.multipart.support.MissingServletRequestPartException;
import org.springframework.web.server.MethodNotAllowedException;
import org.springframework.web.server.MissingRequestValueException;
Expand All @@ -59,6 +58,9 @@
import org.springframework.web.util.BindErrorUtils;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.BDDMockito.mock;
import static org.mockito.BDDMockito.reset;
import static org.mockito.BDDMockito.when;

/**
* Unit tests that verify the HTTP response details exposed by exceptions in the
Expand Down Expand Up @@ -245,20 +247,35 @@ void missingServletRequestPartException() {
@Test
void methodArgumentNotValidException() {

MessageSourceTestHelper messageSourceHelper = new MessageSourceTestHelper(MethodArgumentNotValidException.class);
BindingResult bindingResult = messageSourceHelper.initBindingResult();
ValidationTestHelper testHelper = new ValidationTestHelper(MethodArgumentNotValidException.class);
BindingResult result = testHelper.bindingResult();

MethodArgumentNotValidException ex = new MethodArgumentNotValidException(this.methodParameter, bindingResult);
MethodArgumentNotValidException ex = new MethodArgumentNotValidException(this.methodParameter, result);

assertStatus(ex, HttpStatus.BAD_REQUEST);
assertDetail(ex, "Invalid request content.");
messageSourceHelper.assertDetailMessage(ex);
messageSourceHelper.assertErrorMessages(
(source, locale) -> BindErrorUtils.resolve(ex.getAllErrors(), source, locale));
testHelper.assertMessages(ex, ex.getAllErrors());

assertThat(ex.getHeaders()).isEmpty();
}

@Test
void handlerMethodValidationException() {
MethodValidationResult result = mock(MethodValidationResult.class);
when(result.isForReturnValue()).thenReturn(false);
HandlerMethodValidationException ex = new HandlerMethodValidationException(result);

assertStatus(ex, HttpStatus.BAD_REQUEST);
assertDetail(ex, "Validation failure");

reset(result);
when(result.isForReturnValue()).thenReturn(true);
ex = new HandlerMethodValidationException(result);

assertStatus(ex, HttpStatus.INTERNAL_SERVER_ERROR);
assertDetail(ex, "Validation failure");
}

@Test
void unsupportedMediaTypeStatusException() {

Expand Down Expand Up @@ -360,15 +377,14 @@ void unsatisfiedRequestParameterException() {
@Test
void webExchangeBindException() {

MessageSourceTestHelper messageSourceHelper = new MessageSourceTestHelper(WebExchangeBindException.class);
BindingResult bindingResult = messageSourceHelper.initBindingResult();
ValidationTestHelper testHelper = new ValidationTestHelper(WebExchangeBindException.class);
BindingResult result = testHelper.bindingResult();

WebExchangeBindException ex = new WebExchangeBindException(this.methodParameter, bindingResult);
WebExchangeBindException ex = new WebExchangeBindException(this.methodParameter, result);

assertStatus(ex, HttpStatus.BAD_REQUEST);
assertDetail(ex, "Invalid request content.");
messageSourceHelper.assertDetailMessage(ex);
messageSourceHelper.assertErrorMessages(ex::resolveErrorMessages);
testHelper.assertMessages(ex, ex.getAllErrors());

assertThat(ex.getHeaders()).isEmpty();
}
Expand Down Expand Up @@ -434,59 +450,52 @@ private void assertDetailMessageCode(
private void handle(String arg) {}


private static class MessageSourceTestHelper {
private static class ValidationTestHelper {

private final String code;
private final BindingResult bindingResult;

public MessageSourceTestHelper(Class<? extends ErrorResponse> exceptionType) {
this.code = "problemDetail." + exceptionType.getName();
}
private final StaticMessageSource messageSource = new StaticMessageSource();

public ValidationTestHelper(Class<? extends ErrorResponse> exceptionType) {

public BindingResult initBindingResult() {
BindingResult bindingResult = new BindException(new TestBean(), "myBean");
bindingResult.reject("bean.invalid.A", "Invalid bean message");
bindingResult.reject("bean.invalid.B");
bindingResult.rejectValue("name", "name.required", "must be provided");
bindingResult.rejectValue("age", "age.min");
return bindingResult;
this.bindingResult = new BeanPropertyBindingResult(new TestBean(), "myBean");
this.bindingResult.reject("bean.invalid.A", "Invalid bean message");
this.bindingResult.reject("bean.invalid.B");
this.bindingResult.rejectValue("name", "name.required", "must be provided");
this.bindingResult.rejectValue("age", "age.min");

String code = "problemDetail." + exceptionType.getName();
this.messageSource.addMessage(code, Locale.UK, "Failed because {0}. Also because {1}");
this.messageSource.addMessage("bean.invalid.A", Locale.UK, "Bean A message");
this.messageSource.addMessage("bean.invalid.B", Locale.UK, "Bean B message");
this.messageSource.addMessage("name.required", Locale.UK, "name is required");
this.messageSource.addMessage("age.min", Locale.UK, "age is below minimum");
}

private void assertDetailMessage(ErrorResponse ex) {
public BindingResult bindingResult() {
return this.bindingResult;
}

StaticMessageSource messageSource = initMessageSource();
private void assertMessages(ErrorResponse ex, List<? extends MessageSourceResolvable> errors) {

String message = messageSource.getMessage(
String message = this.messageSource.getMessage(
ex.getDetailMessageCode(), ex.getDetailMessageArguments(), Locale.UK);

assertThat(message).isEqualTo(
"Failed because Invalid bean message, and bean.invalid.B.myBean. " +
"Also because name: must be provided, and age: age.min.myBean.age");

message = messageSource.getMessage(
ex.getDetailMessageCode(), ex.getDetailMessageArguments(messageSource, Locale.UK), Locale.UK);
message = this.messageSource.getMessage(
ex.getDetailMessageCode(), ex.getDetailMessageArguments(this.messageSource, Locale.UK), Locale.UK);

assertThat(message).isEqualTo(
"Failed because Bean A message, and Bean B message. " +
"Also because name is required, and age is below minimum");
}

private void assertErrorMessages(BiFunction<MessageSource, Locale, Map<ObjectError, String>> expectedMessages) {
StaticMessageSource messageSource = initMessageSource();
Map<ObjectError, String> map = expectedMessages.apply(messageSource, Locale.UK);

assertThat(map).hasSize(4).containsValues(
"Bean A message", "Bean B message", "name is required", "age is below minimum");
assertThat(BindErrorUtils.resolve(errors, this.messageSource, Locale.UK)).hasSize(4)
.containsValues("Bean A message", "Bean B message", "name is required", "age is below minimum");
}

private StaticMessageSource initMessageSource() {
StaticMessageSource messageSource = new StaticMessageSource();
messageSource.addMessage(this.code, Locale.UK, "Failed because {0}. Also because {1}");
messageSource.addMessage("bean.invalid.A", Locale.UK, "Bean A message");
messageSource.addMessage("bean.invalid.B", Locale.UK, "Bean B message");
messageSource.addMessage("name.required", Locale.UK, "name is required");
messageSource.addMessage("age.min", Locale.UK, "age is below minimum");
return messageSource;
}
}

}
Loading

0 comments on commit 7a79da5

Please sign in to comment.