diff --git a/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java b/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java index da9b4d81456e..1b137ce8456e 100644 --- a/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java +++ b/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -71,6 +71,7 @@ * * @author Arjen Poutsma * @author Sebastien Deleuze + * @author Injae Kim * @since 6.1 * @see RestClient#create() * @see RestClient#create(String) @@ -99,6 +100,9 @@ final class DefaultRestClient implements RestClient { @Nullable private final HttpHeaders defaultHeaders; + @Nullable + private final Consumer> defaultRequest; + private final List defaultStatusHandlers; private final DefaultRestClientBuilder builder; @@ -116,6 +120,7 @@ final class DefaultRestClient implements RestClient { @Nullable List initializers, UriBuilderFactory uriBuilderFactory, @Nullable HttpHeaders defaultHeaders, + @Nullable Consumer> defaultRequest, @Nullable List statusHandlers, List> messageConverters, ObservationRegistry observationRegistry, @@ -127,6 +132,7 @@ final class DefaultRestClient implements RestClient { this.interceptors = interceptors; this.uriBuilderFactory = uriBuilderFactory; this.defaultHeaders = defaultHeaders; + this.defaultRequest = defaultRequest; this.defaultStatusHandlers = (statusHandlers != null ? new ArrayList<>(statusHandlers) : new ArrayList<>()); this.messageConverters = messageConverters; this.observationRegistry = observationRegistry; @@ -452,6 +458,9 @@ private T exchangeInternal(ExchangeFunction exchangeFunction, boolean clo URI uri = null; try { uri = initUri(); + if (defaultRequest != null) { + defaultRequest.accept(this); + } HttpHeaders headers = initHeaders(); ClientHttpRequest clientRequest = createRequest(uri); clientRequest.getHeaders().addAll(headers); diff --git a/spring-web/src/main/java/org/springframework/web/client/DefaultRestClientBuilder.java b/spring-web/src/main/java/org/springframework/web/client/DefaultRestClientBuilder.java index 4aadfdd7a67a..3a28596afc01 100644 --- a/spring-web/src/main/java/org/springframework/web/client/DefaultRestClientBuilder.java +++ b/spring-web/src/main/java/org/springframework/web/client/DefaultRestClientBuilder.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -58,6 +58,7 @@ * Default implementation of {@link RestClient.Builder}. * * @author Arjen Poutsma + * @author Injae Kim * @since 6.1 */ final class DefaultRestClientBuilder implements RestClient.Builder { @@ -371,6 +372,7 @@ public RestClient build() { return new DefaultRestClient(requestFactory, this.interceptors, this.initializers, uriBuilderFactory, defaultHeaders, + this.defaultRequest, this.statusHandlers, messageConverters, this.observationRegistry, diff --git a/spring-web/src/test/java/org/springframework/web/client/RestClientIntegrationTests.java b/spring-web/src/test/java/org/springframework/web/client/RestClientIntegrationTests.java index 92424a251d25..7d85c4bbf46a 100644 --- a/spring-web/src/test/java/org/springframework/web/client/RestClientIntegrationTests.java +++ b/spring-web/src/test/java/org/springframework/web/client/RestClientIntegrationTests.java @@ -64,6 +64,7 @@ * * @author Arjen Poutsma * @author Sebastien Deleuze + * @author Injae Kim */ class RestClientIntegrationTests { @@ -856,6 +857,35 @@ void filterForErrorHandling(ClientHttpRequestFactory requestFactory) { expectRequestCount(2); } + @ParameterizedRestClientTest + void defaultRequest(ClientHttpRequestFactory requestFactory) { + startServer(requestFactory); + + prepareResponse(response -> + response.setHeader("Content-Type", "text/plain").setBody("Hello Spring!")); + + String result = this.restClient.mutate() + .defaultRequest(spec -> spec + .header("X-Test-Header-Default", "testDefaultValue") + .header("X-Test-Header", "testDefaultValueShouldBeOverride")) + .build() + .get() + .uri("/greeting") + .header("X-Test-Header", "testHeaderValue") + .accept(MediaType.APPLICATION_JSON) + .retrieve() + .body(String.class); + + assertThat(result).isEqualTo("Hello Spring!"); + + expectRequestCount(1); + expectRequest(request -> { + assertThat(request.getHeader("X-Test-Header-Default")).isEqualTo("testDefaultValue"); + assertThat(request.getHeader("X-Test-Header")).isEqualTo("testHeaderValue"); + assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo("application/json"); + assertThat(request.getPath()).isEqualTo("/greeting"); + }); + } private void prepareResponse(Consumer consumer) { MockResponse response = new MockResponse();