Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
poutsma committed Sep 22, 2021
1 parent 8f96ca4 commit 345befd
Show file tree
Hide file tree
Showing 7 changed files with 255 additions and 190 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,21 @@

package org.springframework.web.bind;

import java.lang.reflect.Constructor;
import java.util.List;

import jakarta.servlet.ServletRequest;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.Part;

import org.springframework.beans.MutablePropertyValues;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.core.MethodParameter;
import org.springframework.lang.Nullable;
import org.springframework.util.StringUtils;
import org.springframework.validation.BindException;
import org.springframework.web.multipart.MultipartFile;
import org.springframework.web.multipart.MultipartRequest;
import org.springframework.web.multipart.support.StandardServletPartUtils;
import org.springframework.web.util.WebUtils;
Expand Down Expand Up @@ -142,4 +148,37 @@ public void closeNoCatch() throws ServletRequestBindingException {
}
}

public <T> T construct(ServletRequest request, Constructor<T> ctor, Callback callback, @Nullable MethodParameter parameter) throws Exception {
return super.construct(ctor, (name, type) -> getBindValue(request, name, type), callback, parameter);
}

@Nullable
protected Object getBindValue(ServletRequest request, String name, Class<?> type) {
Object value = request.getParameterValues(name);
if (value != null) {
return value;
}
else {
MultipartRequest multipartRequest = WebUtils.getNativeRequest(request, MultipartRequest.class);
if (multipartRequest != null) {
List<MultipartFile> files = multipartRequest.getFiles(name);
if (!files.isEmpty()) {
return (files.size() == 1 ? files.get(0) : files);
}
}
else if (StringUtils.startsWithIgnoreCase(request.getContentType(), "multipart/")) {
HttpServletRequest httpServletRequest = WebUtils.getNativeRequest(request, HttpServletRequest.class);
if (httpServletRequest != null) {
List<Part> parts = StandardServletPartUtils.getParts(httpServletRequest, name);
if (!parts.isEmpty()) {
return (parts.size() == 1 ? parts.get(0) : parts);
}
}
}
}
return null;
}



}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,15 +16,31 @@

package org.springframework.web.bind;

import java.lang.annotation.Annotation;
import java.lang.reflect.Array;
import java.lang.reflect.Constructor;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.BiFunction;

import org.springframework.beans.BeanInstantiationException;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.MutablePropertyValues;
import org.springframework.beans.PropertyValue;
import org.springframework.beans.TypeMismatchException;
import org.springframework.core.CollectionFactory;
import org.springframework.core.MethodParameter;
import org.springframework.lang.Nullable;
import org.springframework.util.ObjectUtils;
import org.springframework.validation.BindException;
import org.springframework.validation.BindingResult;
import org.springframework.validation.DataBinder;
import org.springframework.web.multipart.MultipartFile;

Expand Down Expand Up @@ -350,4 +366,150 @@ protected void bindMultipart(Map<String, List<MultipartFile>> multipartFiles, Mu
});
}

@SuppressWarnings("serial")
protected <T> T construct(Constructor<T> ctor, BiFunction<String, Class<?>, Object> values,
@Nullable Callback callback, @Nullable MethodParameter parameter) throws Exception {

// A single data class constructor -> resolve constructor arguments from request parameters.
String[] paramNames = BeanUtils.getParameterNames(ctor);
Class<?>[] paramTypes = ctor.getParameterTypes();
Object[] args = new Object[paramTypes.length];
String fieldDefaultPrefix = getFieldDefaultPrefix();
String fieldMarkerPrefix = getFieldMarkerPrefix();
boolean bindingFailure = false;
Set<String> failedParams = new HashSet<>(4);

for (int i = 0; i < paramNames.length; i++) {
String paramName = paramNames[i];
Class<?> paramType = paramTypes[i];
Object value = values.apply(paramName, paramType);

if (ObjectUtils.isArray(value) && Array.getLength(value) == 1) {
value = Array.get(value, 0);
}

if (value == null) {
if (fieldDefaultPrefix != null) {
value = values.apply(fieldDefaultPrefix + paramName, paramType);
}
if (value == null) {
if (fieldMarkerPrefix != null &&
values.apply(fieldMarkerPrefix + paramName, paramType) != null) {
value = getEmptyValue(paramType);
}
}
}
try {
MethodParameter methodParam = new FieldAwareConstructorParameter(ctor, i, paramName);
if (value == null && methodParam.isOptional()) {
args[i] = (methodParam.getParameterType() == Optional.class ? Optional.empty() : null);
}
else {
args[i] = convertIfNecessary(value, paramType, methodParam);
}
}
catch (TypeMismatchException ex) {
ex.initPropertyName(paramName);
args[i] = null;
failedParams.add(paramName);
getBindingResult().recordFieldValue(paramName, paramType, value);
getBindingErrorProcessor().processPropertyAccessException(ex, getBindingResult());
bindingFailure = true;
}
}

if (bindingFailure) {
BindingResult result = getBindingResult();
for (int i = 0; i < paramNames.length; i++) {
String paramName = paramNames[i];
if (!failedParams.contains(paramName)) {
Object value = args[i];
result.recordFieldValue(paramName, paramTypes[i], value);
if (parameter != null && callback != null) {
callback.validateValue(this, parameter, ctor.getDeclaringClass(), paramName, value);
}
}
}
if (parameter != null && !parameter.isOptional()) {
try {
Object target = BeanUtils.instantiateClass(ctor, args);
throw new BindException(result) {
@Override
public Object getTarget() {
return target;
}
};
}
catch (BeanInstantiationException ex) {
// swallow and proceed without target instance
}
}
throw new BindException(result);
}

return BeanUtils.instantiateClass(ctor, args);
}

public interface Callback {

void validateValue(WebDataBinder dataBinder, MethodParameter parameter, Class<?> declaringClass, String paramName, Object value);
}


/**
* {@link MethodParameter} subclass which detects field annotations as well.
*/
private static class FieldAwareConstructorParameter extends MethodParameter {

private final String parameterName;

@Nullable
private volatile Annotation[] combinedAnnotations;

public FieldAwareConstructorParameter(Constructor<?> constructor, int parameterIndex, String parameterName) {
super(constructor, parameterIndex);
this.parameterName = parameterName;
}

@Override
public Annotation[] getParameterAnnotations() {
Annotation[] anns = this.combinedAnnotations;
if (anns == null) {
anns = super.getParameterAnnotations();
try {
Field field = getDeclaringClass().getDeclaredField(this.parameterName);
Annotation[] fieldAnns = field.getAnnotations();
if (fieldAnns.length > 0) {
List<Annotation> merged = new ArrayList<>(anns.length + fieldAnns.length);
merged.addAll(Arrays.asList(anns));
for (Annotation fieldAnn : fieldAnns) {
boolean existingType = false;
for (Annotation ann : anns) {
if (ann.annotationType() == fieldAnn.annotationType()) {
existingType = true;
break;
}
}
if (!existingType) {
merged.add(fieldAnn);
}
}
anns = merged.toArray(new Annotation[0]);
}
}
catch (NoSuchFieldException | SecurityException ex) {
// ignore
}
this.combinedAnnotations = anns;
}
return anns;
}

@Override
public String getParameterName() {
return this.parameterName;
}
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package org.springframework.web.bind.support;

import java.lang.reflect.Constructor;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
Expand All @@ -24,6 +25,7 @@
import reactor.core.publisher.Mono;

import org.springframework.beans.MutablePropertyValues;
import org.springframework.core.MethodParameter;
import org.springframework.http.codec.multipart.FormFieldPart;
import org.springframework.http.codec.multipart.Part;
import org.springframework.lang.Nullable;
Expand Down Expand Up @@ -85,6 +87,19 @@ public Mono<Map<String, Object>> getValuesToBind(ServerWebExchange exchange) {
return extractValuesToBind(exchange);
}

public <T> Mono<T> construct(ServerWebExchange exchange, Constructor<T> ctor,
@Nullable MethodParameter parameter) {
return getValuesToBind(exchange).flatMap(bindValues -> {
try {
return Mono.just(super.construct(ctor, (name, type) -> bindValues.get(name), null, parameter));
}
catch (Exception ex) {
return Mono.error(ex);
}
});
}



/**
* Combine query params and form data for multipart form data from the body
Expand Down
Loading

0 comments on commit 345befd

Please sign in to comment.