From d0ced8825a9bd5e25a7feb7a16df3d277cb88589 Mon Sep 17 00:00:00 2001 From: Raman Gupta Date: Fri, 22 Nov 2024 05:50:25 -0500 Subject: [PATCH] KTOR-7722 content negotiation client accept header control (#4462) --- .../api/ktor-client-content-negotiation.api | 3 + .../ktor-client-content-negotiation.klib.api | 6 + .../contentnegotiation/ContentNegotiation.kt | 50 +++++- .../client/plugins/ContentNegotiationTests.kt | 160 ++++++++++++++++++ .../common/src/io/ktor/http/ContentTypes.kt | 8 + 5 files changed, 221 insertions(+), 6 deletions(-) diff --git a/ktor-client/ktor-client-plugins/ktor-client-content-negotiation/api/ktor-client-content-negotiation.api b/ktor-client/ktor-client-plugins/ktor-client-content-negotiation/api/ktor-client-content-negotiation.api index ad13c0a333b..912d384413f 100644 --- a/ktor-client/ktor-client-plugins/ktor-client-content-negotiation/api/ktor-client-content-negotiation.api +++ b/ktor-client/ktor-client-plugins/ktor-client-content-negotiation/api/ktor-client-content-negotiation.api @@ -5,13 +5,16 @@ public final class io/ktor/client/plugins/contentnegotiation/ContentConverterExc public final class io/ktor/client/plugins/contentnegotiation/ContentNegotiationConfig : io/ktor/serialization/Configuration { public fun ()V public final fun clearIgnoredTypes ()V + public final fun getDefaultAcceptHeaderQValue ()Ljava/lang/Double; public final fun ignoreType (Lkotlin/reflect/KClass;)V public final fun register (Lio/ktor/http/ContentType;Lio/ktor/serialization/ContentConverter;Lio/ktor/http/ContentTypeMatcher;Lkotlin/jvm/functions/Function1;)V public fun register (Lio/ktor/http/ContentType;Lio/ktor/serialization/ContentConverter;Lkotlin/jvm/functions/Function1;)V public final fun removeIgnoredType (Lkotlin/reflect/KClass;)V + public final fun setDefaultAcceptHeaderQValue (Ljava/lang/Double;)V } public final class io/ktor/client/plugins/contentnegotiation/ContentNegotiationKt { + public static final fun exclude (Lio/ktor/client/request/HttpRequestBuilder;[Lio/ktor/http/ContentType;)V public static final fun getContentNegotiation ()Lio/ktor/client/plugins/api/ClientPlugin; } diff --git a/ktor-client/ktor-client-plugins/ktor-client-content-negotiation/api/ktor-client-content-negotiation.klib.api b/ktor-client/ktor-client-plugins/ktor-client-content-negotiation/api/ktor-client-content-negotiation.klib.api index 06315ef5525..2638ca4ba91 100644 --- a/ktor-client/ktor-client-plugins/ktor-client-content-negotiation/api/ktor-client-content-negotiation.klib.api +++ b/ktor-client/ktor-client-plugins/ktor-client-content-negotiation/api/ktor-client-content-negotiation.klib.api @@ -13,6 +13,10 @@ final class io.ktor.client.plugins.contentnegotiation/ContentConverterException final class io.ktor.client.plugins.contentnegotiation/ContentNegotiationConfig : io.ktor.serialization/Configuration { // io.ktor.client.plugins.contentnegotiation/ContentNegotiationConfig|null[0] constructor () // io.ktor.client.plugins.contentnegotiation/ContentNegotiationConfig.|(){}[0] + final var defaultAcceptHeaderQValue // io.ktor.client.plugins.contentnegotiation/ContentNegotiationConfig.defaultAcceptHeaderQValue|{}defaultAcceptHeaderQValue[0] + final fun (): kotlin/Double? // io.ktor.client.plugins.contentnegotiation/ContentNegotiationConfig.defaultAcceptHeaderQValue.|(){}[0] + final fun (kotlin/Double?) // io.ktor.client.plugins.contentnegotiation/ContentNegotiationConfig.defaultAcceptHeaderQValue.|(kotlin.Double?){}[0] + final fun <#A1: io.ktor.serialization/ContentConverter> register(io.ktor.http/ContentType, #A1, io.ktor.http/ContentTypeMatcher, kotlin/Function1<#A1, kotlin/Unit>) // io.ktor.client.plugins.contentnegotiation/ContentNegotiationConfig.register|register(io.ktor.http.ContentType;0:0;io.ktor.http.ContentTypeMatcher;kotlin.Function1<0:0,kotlin.Unit>){0§}[0] final fun <#A1: io.ktor.serialization/ContentConverter> register(io.ktor.http/ContentType, #A1, kotlin/Function1<#A1, kotlin/Unit>) // io.ktor.client.plugins.contentnegotiation/ContentNegotiationConfig.register|register(io.ktor.http.ContentType;0:0;kotlin.Function1<0:0,kotlin.Unit>){0§}[0] final fun clearIgnoredTypes() // io.ktor.client.plugins.contentnegotiation/ContentNegotiationConfig.clearIgnoredTypes|clearIgnoredTypes(){}[0] @@ -28,3 +32,5 @@ final object io.ktor.client.plugins.contentnegotiation/JsonContentTypeMatcher : final val io.ktor.client.plugins.contentnegotiation/ContentNegotiation // io.ktor.client.plugins.contentnegotiation/ContentNegotiation|{}ContentNegotiation[0] final fun (): io.ktor.client.plugins.api/ClientPlugin // io.ktor.client.plugins.contentnegotiation/ContentNegotiation.|(){}[0] + +final fun (io.ktor.client.request/HttpRequestBuilder).io.ktor.client.plugins.contentnegotiation/exclude(kotlin/Array...) // io.ktor.client.plugins.contentnegotiation/exclude|exclude@io.ktor.client.request.HttpRequestBuilder(kotlin.Array...){}[0] diff --git a/ktor-client/ktor-client-plugins/ktor-client-content-negotiation/common/src/io/ktor/client/plugins/contentnegotiation/ContentNegotiation.kt b/ktor-client/ktor-client-plugins/ktor-client-content-negotiation/common/src/io/ktor/client/plugins/contentnegotiation/ContentNegotiation.kt index de812bd1ea9..21edbe71305 100644 --- a/ktor-client/ktor-client-plugins/ktor-client-content-negotiation/common/src/io/ktor/client/plugins/contentnegotiation/ContentNegotiation.kt +++ b/ktor-client/ktor-client-plugins/ktor-client-content-negotiation/common/src/io/ktor/client/plugins/contentnegotiation/ContentNegotiation.kt @@ -11,6 +11,7 @@ import io.ktor.client.utils.* import io.ktor.http.* import io.ktor.http.content.* import io.ktor.serialization.* +import io.ktor.util.AttributeKey import io.ktor.util.logging.* import io.ktor.util.reflect.* import io.ktor.utils.io.* @@ -29,6 +30,12 @@ internal val DefaultCommonIgnoredTypes: Set> = setOf( internal expect val DefaultIgnoredTypes: Set> +/** + * The content types that are excluded from the `Accept` header for this specific request. Use the + * [exclude] `HttpRequestBuilder` extension to set this attribute on a request. + */ +internal val ExcludedContentTypes: AttributeKey> = AttributeKey("ExcludedContentTypesAttr") + /** * A [ContentNegotiation] configuration that is used during installation. */ @@ -46,6 +53,12 @@ public class ContentNegotiationConfig : Configuration { internal val registrations = mutableListOf() + /** + * By default, `Accept` headers for registered content types will have no q value (implicit 1.0). Set this to + * change that behavior. This is useful to override the preferred `Accept` content types on a per-request basis. + */ + public var defaultAcceptHeaderQValue: Double? = null + /** * Registers a [contentType] to a specified [converter] with an optional [configuration] script for a converter. */ @@ -54,8 +67,8 @@ public class ContentNegotiationConfig : Configuration { converter: T, configuration: T.() -> Unit ) { - val matcher = when (contentType) { - ContentType.Application.Json -> JsonContentTypeMatcher + val matcher = when { + contentType.match(ContentType.Application.Json) -> JsonContentTypeMatcher else -> defaultMatcher(contentType) } register(contentType, converter, matcher, configuration) @@ -140,11 +153,25 @@ public val ContentNegotiation: ClientPlugin = createCl val ignoredTypes: Set> = pluginConfig.ignoredTypes suspend fun convertRequest(request: HttpRequestBuilder, body: Any): OutgoingContent? { - registrations.forEach { - LOGGER.trace("Adding Accept=${it.contentTypeToSend.contentType} header for ${request.url}") + val requestRegistrations = if (request.attributes.contains(ExcludedContentTypes)) { + val excluded = request.attributes[ExcludedContentTypes] + registrations.filter { registration -> excluded.none { registration.contentTypeToSend.match(it) } } + } else { + registrations + } - if (request.headers.contains(HttpHeaders.Accept, it.contentTypeToSend.toString())) return@forEach - request.accept(it.contentTypeToSend) + val acceptHeaders = request.headers.getAll(HttpHeaders.Accept).orEmpty() + requestRegistrations.forEach { + if (acceptHeaders.none { h -> ContentType.parse(h).match(it.contentTypeToSend) }) { + // automatically added headers get a lower content type priority, so user-specified accept headers + // with higher q or implicit q=1 will take precedence + val contentTypeToSend = when (val qValue = pluginConfig.defaultAcceptHeaderQValue) { + null -> it.contentTypeToSend + else -> it.contentTypeToSend.withParameter("q", qValue.toString()) + } + LOGGER.trace("Adding Accept=$contentTypeToSend header for ${request.url}") + request.accept(contentTypeToSend) + } } if (body is OutgoingContent || ignoredTypes.any { it.isInstance(body) }) { @@ -251,3 +278,14 @@ public val ContentNegotiation: ClientPlugin = createCl } public class ContentConverterException(message: String) : Exception(message) + +/** + * Excludes the given [ContentType] from the list of types that will be sent in the `Accept` header by + * the [ContentNegotiation] plugin. Can be used to not accept specific types for particular requests. + * This can be called multiple times to exclude multiple content types, or multiple content types can + * be passed in a single call. + */ +public fun HttpRequestBuilder.exclude(vararg contentType: ContentType) { + val excludedContentTypes = attributes.getOrNull(ExcludedContentTypes).orEmpty() + attributes.put(ExcludedContentTypes, excludedContentTypes + contentType) +} diff --git a/ktor-client/ktor-client-plugins/ktor-client-content-negotiation/common/test/io/ktor/client/plugins/ContentNegotiationTests.kt b/ktor-client/ktor-client-plugins/ktor-client-content-negotiation/common/test/io/ktor/client/plugins/ContentNegotiationTests.kt index 8812f098753..38783d52009 100644 --- a/ktor-client/ktor-client-plugins/ktor-client-content-negotiation/common/test/io/ktor/client/plugins/ContentNegotiationTests.kt +++ b/ktor-client/ktor-client-plugins/ktor-client-content-negotiation/common/test/io/ktor/client/plugins/ContentNegotiationTests.kt @@ -71,6 +71,166 @@ class ContentNegotiationTests { } } + @Test + fun addAcceptHeadersWithSingleExclusion() { + testWithEngine(MockEngine) { + val registeredTypesToSend = listOf( + ContentType("testing", "a"), + ContentType("testing", "b"), + ContentType("testing", "c") + ) + + setupWithContentNegotiation { + for (typeToSend in registeredTypesToSend) { + register(typeToSend, TestContentConverter()) + } + } + + test { client -> + client.get("https://test.com/") { + exclude(ContentType("testing", "b")) + }.apply { + val sentTypes = assertNotNull(call.request.headers.getAll(HttpHeaders.Accept)) + .map { ContentType.parse(it) } + + // Order NOT tested + for (typeToSend in registeredTypesToSend.filter { it.contentSubtype != "b" }) { + assertContains(sentTypes, typeToSend) + } + assertNull(sentTypes.firstOrNull { it.contentSubtype == "b" }) + } + } + } + } + + @Test + fun addAcceptHeadersWithExclusionMatchingParameterizedType() { + testWithEngine(MockEngine) { + val registeredTypesToSend = listOf( + ContentType("testing", "a").withParameter("foo", "bar"), + ContentType("testing", "b").withParameter("foo", "bar"), + ContentType("testing", "c").withParameter("foo", "bar") + ) + + setupWithContentNegotiation { + for (typeToSend in registeredTypesToSend) { + register(typeToSend, TestContentConverter()) + } + } + + test { client -> + client.get("https://test.com/") { + exclude(ContentType("testing", "b")) + }.apply { + val sentTypes = assertNotNull(call.request.headers.getAll(HttpHeaders.Accept)) + .map { ContentType.parse(it) } + + // Order NOT tested + for (typeToSend in registeredTypesToSend.filter { it.contentSubtype != "b" }) { + assertContains(sentTypes, typeToSend) + } + assertNull(sentTypes.firstOrNull { it.contentSubtype == "b" }) + } + } + } + } + + @Test + fun addAcceptHeadersWithMultipleExclusions() { + testWithEngine(MockEngine) { + val registeredTypesToSend = listOf( + ContentType("testing", "a"), + ContentType("testing", "b"), + ContentType("testing", "c") + ) + + setupWithContentNegotiation { + for (typeToSend in registeredTypesToSend) { + register(typeToSend, TestContentConverter()) + } + } + + test { client -> + client.get("https://test.com/") { + exclude(ContentType("testing", "b")) + exclude(ContentType("testing", "c")) + }.apply { + val sentTypes = assertNotNull(call.request.headers.getAll(HttpHeaders.Accept)) + .map { ContentType.parse(it) } + + // Order NOT tested + assertTrue(sentTypes.size == 1) + assertContains(sentTypes, ContentType("testing", "a")) + } + } + + test { client -> + client.get("https://test.com/") { + exclude(ContentType("testing", "b"), ContentType("testing", "c")) + }.apply { + val sentTypes = assertNotNull(call.request.headers.getAll(HttpHeaders.Accept)) + .map { ContentType.parse(it) } + + // Order NOT tested + assertTrue(sentTypes.size == 1) + assertContains(sentTypes, ContentType("testing", "a")) + } + } + } + } + + @Test + fun addAcceptHeadersWithDefaultQValue() { + testWithEngine(MockEngine) { + val registeredTypesToSend = listOf( + ContentType("testing", "a"), + ContentType("testing", "b"), + ContentType("testing", "c") + ) + + setupWithContentNegotiation { + for (typeToSend in registeredTypesToSend) { + register(typeToSend, TestContentConverter()) + defaultAcceptHeaderQValue = 0.8 + } + } + + test { client -> + client.get("https://test.com/").apply { + val sentTypes = assertNotNull(call.request.headers.getAll(HttpHeaders.Accept)) + .map { ContentType.parse(it) } + + // Order NOT tested + for (typeToSend in registeredTypesToSend) { + assertContains(sentTypes, typeToSend.withParameter("q", "0.8")) + } + } + } + } + } + + @Test + fun skipAddAcceptHeadersWithMatchingContentType() { + testWithEngine(MockEngine) { + setupWithContentNegotiation { + register(ContentType("testing", "a"), TestContentConverter()) + } + + test { client -> + client.get("https://test.com/") { + // our explicitly specified lower q-value should take precedence + accept(ContentType("testing", "a", listOf(HeaderValueParam("q", "0.5")))) + }.apply { + val sentTypes = assertNotNull(call.request.headers.getAll(HttpHeaders.Accept)) + .map { ContentType.parse(it) } + + assertContains(sentTypes, ContentType("testing", "a", listOf(HeaderValueParam("q", "0.5")))) + assertEquals(1, sentTypes.size) + } + } + } + } + @Test fun testKeepsContentType() { testWithEngine(MockEngine) { diff --git a/ktor-http/common/src/io/ktor/http/ContentTypes.kt b/ktor-http/common/src/io/ktor/http/ContentTypes.kt index c2444bc5c74..117c423d90e 100644 --- a/ktor-http/common/src/io/ktor/http/ContentTypes.kt +++ b/ktor-http/common/src/io/ktor/http/ContentTypes.kt @@ -54,6 +54,14 @@ public class ContentType private constructor( /** * Checks if `this` type matches a [pattern] type taking into account placeholder symbols `*` and parameters. + * The `this` type must be a more specific type than the [pattern] type. In other words: + * + * ```kotlin + * ContentType("a", "b").match(ContentType("a", "b").withParameter("foo", "bar")) === false + * ContentType("a", "b").withParameter("foo", "bar").match(ContentType("a", "b")) === true + * ContentType("a", "*").match(ContentType("a", "b")) === false + * ContentType("a", "b").match(ContentType("a", "*")) === true + * ``` */ public fun match(pattern: ContentType): Boolean { if (pattern.contentType != "*" && !pattern.contentType.equals(contentType, ignoreCase = true)) {