diff --git a/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java b/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java index da9b4d81456e..196f173246e0 100644 --- a/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java +++ b/spring-web/src/main/java/org/springframework/web/client/DefaultRestClient.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -99,6 +99,9 @@ final class DefaultRestClient implements RestClient { @Nullable private final HttpHeaders defaultHeaders; + @Nullable + private final Consumer> defaultRequest; + private final List defaultStatusHandlers; private final DefaultRestClientBuilder builder; @@ -116,6 +119,7 @@ final class DefaultRestClient implements RestClient { @Nullable List initializers, UriBuilderFactory uriBuilderFactory, @Nullable HttpHeaders defaultHeaders, + @Nullable Consumer> defaultRequest, @Nullable List statusHandlers, List> messageConverters, ObservationRegistry observationRegistry, @@ -127,6 +131,7 @@ final class DefaultRestClient implements RestClient { this.interceptors = interceptors; this.uriBuilderFactory = uriBuilderFactory; this.defaultHeaders = defaultHeaders; + this.defaultRequest = defaultRequest; this.defaultStatusHandlers = (statusHandlers != null ? new ArrayList<>(statusHandlers) : new ArrayList<>()); this.messageConverters = messageConverters; this.observationRegistry = observationRegistry; @@ -451,6 +456,9 @@ private T exchangeInternal(ExchangeFunction exchangeFunction, boolean clo Observation observation = null; URI uri = null; try { + if (DefaultRestClient.this.defaultRequest != null) { + DefaultRestClient.this.defaultRequest.accept(this); + } uri = initUri(); HttpHeaders headers = initHeaders(); ClientHttpRequest clientRequest = createRequest(uri); diff --git a/spring-web/src/main/java/org/springframework/web/client/DefaultRestClientBuilder.java b/spring-web/src/main/java/org/springframework/web/client/DefaultRestClientBuilder.java index 4aadfdd7a67a..2dfd0d7bdc86 100644 --- a/spring-web/src/main/java/org/springframework/web/client/DefaultRestClientBuilder.java +++ b/spring-web/src/main/java/org/springframework/web/client/DefaultRestClientBuilder.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -371,6 +371,7 @@ public RestClient build() { return new DefaultRestClient(requestFactory, this.interceptors, this.initializers, uriBuilderFactory, defaultHeaders, + this.defaultRequest, this.statusHandlers, messageConverters, this.observationRegistry, diff --git a/spring-web/src/test/java/org/springframework/web/client/RestClientIntegrationTests.java b/spring-web/src/test/java/org/springframework/web/client/RestClientIntegrationTests.java index 92424a251d25..673ae08cc799 100644 --- a/spring-web/src/test/java/org/springframework/web/client/RestClientIntegrationTests.java +++ b/spring-web/src/test/java/org/springframework/web/client/RestClientIntegrationTests.java @@ -856,6 +856,50 @@ void filterForErrorHandling(ClientHttpRequestFactory requestFactory) { expectRequestCount(2); } + @ParameterizedRestClientTest + void defaultHeaders(ClientHttpRequestFactory requestFactory) { + startServer(requestFactory); + + prepareResponse(response -> response.setHeader("Content-Type", "text/plain") + .setBody("Hello Spring!")); + + RestClient headersClient = this.restClient.mutate() + .defaultHeaders(headers -> headers.add("foo", "bar")) + .build(); + + String result = headersClient.get() + .uri("/greeting") + .retrieve() + .body(String.class); + + assertThat(result).isEqualTo("Hello Spring!"); + + expectRequestCount(1); + expectRequest(request -> assertThat(request.getHeader("foo")).isEqualTo("bar")); + } + + @ParameterizedRestClientTest + void defaultRequest(ClientHttpRequestFactory requestFactory) { + startServer(requestFactory); + + prepareResponse(response -> response.setHeader("Content-Type", "text/plain") + .setBody("Hello Spring!")); + + RestClient headersClient = this.restClient.mutate() + .defaultRequest(request -> request.header("foo", "bar")) + .build(); + + String result = headersClient.get() + .uri("/greeting") + .retrieve() + .body(String.class); + + assertThat(result).isEqualTo("Hello Spring!"); + + expectRequestCount(1); + expectRequest(request -> assertThat(request.getHeader("foo")).isEqualTo("bar")); + } + private void prepareResponse(Consumer consumer) { MockResponse response = new MockResponse();