From 585b8e07ff512d00f449b727aa69cb1abf382fcf Mon Sep 17 00:00:00 2001 From: Sean Yang Date: Wed, 25 Sep 2024 23:32:05 +0800 Subject: [PATCH] Support configuration of disallowed content-types --- .../dubbo/config/nested/RestConfig.java | 13 ++++++++ .../protocol/tri/servlet/TripleFilter.java | 4 ++- .../http12/message/codec/CodecUtils.java | 30 ++++++++++++++----- .../java/org/apache/dubbo/rpc/Constants.java | 1 + .../Http1UnaryServerChannelObserver.java | 15 +++++++++- 5 files changed, 53 insertions(+), 10 deletions(-) diff --git a/dubbo-common/src/main/java/org/apache/dubbo/config/nested/RestConfig.java b/dubbo-common/src/main/java/org/apache/dubbo/config/nested/RestConfig.java index 4e8d6a54eff..be813c68e59 100644 --- a/dubbo-common/src/main/java/org/apache/dubbo/config/nested/RestConfig.java +++ b/dubbo-common/src/main/java/org/apache/dubbo/config/nested/RestConfig.java @@ -67,6 +67,11 @@ public class RestConfig implements Serializable { */ private String jsonFramework; + /** + * The disallowed content-types. + */ + private String[] disallowedContentTypes; + /** * The cors configuration. */ @@ -133,6 +138,14 @@ public void setJsonFramework(String jsonFramework) { this.jsonFramework = jsonFramework; } + public String[] getDisallowedContentTypes() { + return disallowedContentTypes; + } + + public void setDisallowedContentTypes(String[] disallowedContentTypes) { + this.disallowedContentTypes = disallowedContentTypes; + } + public CorsConfig getCors() { return cors; } diff --git a/dubbo-plugin/dubbo-triple-servlet/src/main/java/org/apache/dubbo/rpc/protocol/tri/servlet/TripleFilter.java b/dubbo-plugin/dubbo-triple-servlet/src/main/java/org/apache/dubbo/rpc/protocol/tri/servlet/TripleFilter.java index 85175d3d24e..891cc244409 100644 --- a/dubbo-plugin/dubbo-triple-servlet/src/main/java/org/apache/dubbo/rpc/protocol/tri/servlet/TripleFilter.java +++ b/dubbo-plugin/dubbo-triple-servlet/src/main/java/org/apache/dubbo/rpc/protocol/tri/servlet/TripleFilter.java @@ -125,8 +125,10 @@ private void handleHttp1(HttpServletRequest request, HttpServletResponse respons channel, ServletExchanger.getUrl(), FrameworkModel.defaultModel()); channel.setGrpc(false); context.setTimeout(resolveTimeout(request, false)); - listener.onMetadata(new HttpMetadataAdapter(request)); ServletInputStream is = request.getInputStream(); + response.getOutputStream().setWriteListener(new TripleWriteListener(channel)); + + listener.onMetadata(new HttpMetadataAdapter(request)); listener.onData(new Http1InputMessage( is.available() == 0 ? StreamUtils.EMPTY : new ByteArrayInputStream(StreamUtils.readBytes(is)))); } catch (Throwable t) { diff --git a/dubbo-remoting/dubbo-remoting-http12/src/main/java/org/apache/dubbo/remoting/http12/message/codec/CodecUtils.java b/dubbo-remoting/dubbo-remoting-http12/src/main/java/org/apache/dubbo/remoting/http12/message/codec/CodecUtils.java index 2b63feb1b09..143857e5fa0 100644 --- a/dubbo-remoting/dubbo-remoting-http12/src/main/java/org/apache/dubbo/remoting/http12/message/codec/CodecUtils.java +++ b/dubbo-remoting/dubbo-remoting-http12/src/main/java/org/apache/dubbo/remoting/http12/message/codec/CodecUtils.java @@ -17,17 +17,24 @@ package org.apache.dubbo.remoting.http12.message.codec; import org.apache.dubbo.common.URL; +import org.apache.dubbo.common.config.Configuration; +import org.apache.dubbo.common.config.ConfigurationUtils; import org.apache.dubbo.common.utils.Assert; +import org.apache.dubbo.common.utils.StringUtils; import org.apache.dubbo.remoting.http12.exception.UnsupportedMediaTypeException; import org.apache.dubbo.remoting.http12.message.HttpMessageDecoder; import org.apache.dubbo.remoting.http12.message.HttpMessageDecoderFactory; import org.apache.dubbo.remoting.http12.message.HttpMessageEncoder; import org.apache.dubbo.remoting.http12.message.HttpMessageEncoderFactory; +import org.apache.dubbo.rpc.Constants; import org.apache.dubbo.rpc.model.FrameworkModel; +import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; public final class CodecUtils { @@ -37,13 +44,18 @@ public final class CodecUtils { private final List encoderFactories; private final Map> encoderCache = new ConcurrentHashMap<>(); private final Map> decoderCache = new ConcurrentHashMap<>(); + private Set disallowedContentTypes = Collections.emptySet(); public CodecUtils(FrameworkModel frameworkModel) { this.frameworkModel = frameworkModel; decoderFactories = frameworkModel.getActivateExtensions(HttpMessageDecoderFactory.class); encoderFactories = frameworkModel.getActivateExtensions(HttpMessageEncoderFactory.class); - decoderFactories.forEach(f -> decoderCache.putIfAbsent(f.mediaType().getName(), Optional.of(f))); - encoderFactories.forEach(f -> encoderCache.putIfAbsent(f.mediaType().getName(), Optional.of(f))); + + Configuration configuration = ConfigurationUtils.getGlobalConfiguration(frameworkModel.defaultApplication()); + String contentTypes = configuration.getString(Constants.H2_SETTINGS_DISALLOWED_CONTENT_TYPES, null); + if (contentTypes != null) { + disallowedContentTypes = new HashSet<>(StringUtils.tokenizeToList(contentTypes)); + } } public HttpMessageDecoder determineHttpMessageDecoder(URL url, String mediaType) { @@ -69,9 +81,10 @@ public HttpMessageEncoder determineHttpMessageEncoder(String mediaType) { public Optional determineHttpMessageDecoderFactory(String mediaType) { Assert.notNull(mediaType, "mediaType must not be null"); return decoderCache.computeIfAbsent(mediaType, k -> { - for (HttpMessageDecoderFactory decoderFactory : decoderFactories) { - if (decoderFactory.supports(k)) { - return Optional.of(decoderFactory); + for (HttpMessageDecoderFactory factory : decoderFactories) { + if (factory.supports(k) + && !disallowedContentTypes.contains(factory.mediaType().getName())) { + return Optional.of(factory); } } return Optional.empty(); @@ -81,9 +94,10 @@ public Optional determineHttpMessageDecoderFactory(St public Optional determineHttpMessageEncoderFactory(String mediaType) { Assert.notNull(mediaType, "mediaType must not be null"); return encoderCache.computeIfAbsent(mediaType, k -> { - for (HttpMessageEncoderFactory encoderFactory : encoderFactories) { - if (encoderFactory.supports(k)) { - return Optional.of(encoderFactory); + for (HttpMessageEncoderFactory factory : encoderFactories) { + if (factory.supports(k) + && !disallowedContentTypes.contains(factory.mediaType().getName())) { + return Optional.of(factory); } } return Optional.empty(); diff --git a/dubbo-rpc/dubbo-rpc-api/src/main/java/org/apache/dubbo/rpc/Constants.java b/dubbo-rpc/dubbo-rpc-api/src/main/java/org/apache/dubbo/rpc/Constants.java index c5b4bdb18a8..6fee3176686 100644 --- a/dubbo-rpc/dubbo-rpc-api/src/main/java/org/apache/dubbo/rpc/Constants.java +++ b/dubbo-rpc/dubbo-rpc-api/src/main/java/org/apache/dubbo/rpc/Constants.java @@ -109,6 +109,7 @@ public interface Constants { String H2_SETTINGS_BUILTIN_SERVICE_INIT = "dubbo.tri.builtin.service.init"; String H2_SETTINGS_JSON_FRAMEWORK_NAME = "dubbo.protocol.triple.rest.json-framework"; + String H2_SETTINGS_DISALLOWED_CONTENT_TYPES = "dubbo.protocol.triple.rest.disallowed-content-types"; String H2_SETTINGS_VERBOSE_ENABLED = "dubbo.protocol.triple.verbose"; String H2_SETTINGS_SERVLET_ENABLED = "dubbo.protocol.triple.servlet.enabled"; diff --git a/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/h12/http1/Http1UnaryServerChannelObserver.java b/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/h12/http1/Http1UnaryServerChannelObserver.java index eff04e89072..5b543957a27 100644 --- a/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/h12/http1/Http1UnaryServerChannelObserver.java +++ b/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/h12/http1/Http1UnaryServerChannelObserver.java @@ -24,6 +24,9 @@ import org.apache.dubbo.rpc.protocol.tri.ExceptionUtils; import org.apache.dubbo.rpc.protocol.tri.TripleProtocol; +import java.io.ByteArrayOutputStream; +import java.io.OutputStream; + import io.netty.buffer.ByteBufOutputStream; public final class Http1UnaryServerChannelObserver extends Http1ServerChannelObserver { @@ -52,7 +55,17 @@ protected void doOnError(Throwable throwable) throws Throwable { @Override protected void customizeHeaders(HttpHeaders headers, Throwable throwable, HttpOutputMessage message) { super.customizeHeaders(headers, throwable, message); - int contentLength = message == null ? 0 : ((ByteBufOutputStream) message.getBody()).writtenBytes(); + int contentLength = 0; + if (message != null) { + OutputStream body = message.getBody(); + if (body instanceof ByteBufOutputStream) { + contentLength = ((ByteBufOutputStream) body).writtenBytes(); + } else if (body instanceof ByteArrayOutputStream) { + contentLength = ((ByteArrayOutputStream) body).size(); + } else { + throw new IllegalArgumentException("Unsupported body type: " + body.getClass()); + } + } headers.set(HttpHeaderNames.CONTENT_LENGTH.getName(), String.valueOf(contentLength)); }