Skip to content

Commit

Permalink
Use DataSource.unwrap to get routing data source
Browse files Browse the repository at this point in the history
This commit uses DataSource.isWrapperFor and DataSource.unwrap to detect
if a DataSource is an AbstractRoutingDataSource. Previously, it relied
on instanceof which does not account for cases where the datasource has
been proxied.

See gh-42313
  • Loading branch information
nosan authored and snicoll committed Sep 16, 2024
1 parent 99d0805 commit 3f9f049
Show file tree
Hide file tree
Showing 2 changed files with 177 additions and 7 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2012-2023 the original author or authors.
* Copyright 2012-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,6 +16,7 @@

package org.springframework.boot.actuate.autoconfigure.jdbc;

import java.sql.SQLException;
import java.util.Collection;
import java.util.Iterator;
import java.util.Map;
Expand Down Expand Up @@ -88,7 +89,7 @@ public HealthContributor dbHealthContributor(Map<String, DataSource> dataSources
if (dataSourceHealthIndicatorProperties.isIgnoreRoutingDataSources()) {
Map<String, DataSource> filteredDatasources = dataSources.entrySet()
.stream()
.filter((e) -> !(e.getValue() instanceof AbstractRoutingDataSource))
.filter((e) -> !isAbstractRoutingDataSource(e.getValue()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
return createContributor(filteredDatasources);
}
Expand All @@ -104,8 +105,9 @@ private HealthContributor createContributor(Map<String, DataSource> beans) {
}

private HealthContributor createContributor(DataSource source) {
if (source instanceof AbstractRoutingDataSource routingDataSource) {
return new RoutingDataSourceHealthContributor(routingDataSource, this::createContributor);
if (isAbstractRoutingDataSource(source)) {
return new RoutingDataSourceHealthContributor(unwrapAbstractRoutingDataSource(source),
this::createContributor);
}
return new DataSourceHealthIndicator(source, getValidationQuery(source));
}
Expand All @@ -115,6 +117,31 @@ private String getValidationQuery(DataSource source) {
return (poolMetadata != null) ? poolMetadata.getValidationQuery() : null;
}

private static boolean isAbstractRoutingDataSource(DataSource dataSource) {
if (dataSource instanceof AbstractRoutingDataSource) {
return true;
}
try {
return dataSource.isWrapperFor(AbstractRoutingDataSource.class);
}
catch (SQLException ex) {
return false;
}
}

private static AbstractRoutingDataSource unwrapAbstractRoutingDataSource(DataSource dataSource) {
if (dataSource instanceof AbstractRoutingDataSource routingDataSource) {
return routingDataSource;
}
try {
return dataSource.unwrap(AbstractRoutingDataSource.class);
}
catch (SQLException ex) {
throw new IllegalStateException(
"DataSource '%s' failed to unwrap '%s'".formatted(dataSource, AbstractRoutingDataSource.class), ex);
}
}

/**
* {@link CompositeHealthContributor} used for {@link AbstractRoutingDataSource} beans
* where the overall health is composed of a {@link DataSourceHealthIndicator} for
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2012-2023 the original author or authors.
* Copyright 2012-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,13 +16,22 @@

package org.springframework.boot.actuate.autoconfigure.jdbc;

import java.io.PrintWriter;
import java.sql.Connection;
import java.sql.ConnectionBuilder;
import java.sql.SQLException;
import java.sql.SQLFeatureNotSupportedException;
import java.sql.ShardingKeyBuilder;
import java.util.HashMap;
import java.util.Map;
import java.util.logging.Logger;

import javax.sql.DataSource;

import org.junit.jupiter.api.Test;

import org.springframework.beans.BeansException;
import org.springframework.beans.factory.config.BeanPostProcessor;
import org.springframework.boot.actuate.autoconfigure.health.HealthContributorAutoConfiguration;
import org.springframework.boot.actuate.autoconfigure.jdbc.DataSourceHealthContributorAutoConfiguration.RoutingDataSourceHealthContributor;
import org.springframework.boot.actuate.health.CompositeHealthContributor;
Expand Down Expand Up @@ -87,6 +96,19 @@ void runWithRoutingAndEmbeddedDataSourceShouldIncludeRoutingDataSource() {
});
}

@Test
void runWithProxyBeanPostProcessorRoutingAndEmbeddedDataSourceShouldIncludeRoutingDataSource() {
this.contextRunner
.withUserConfiguration(ProxyDataSourceBeanPostProcessor.class, EmbeddedDataSourceConfiguration.class,
RoutingDataSourceConfig.class)
.run((context) -> {
CompositeHealthContributor composite = context.getBean(CompositeHealthContributor.class);
assertThat(composite.getContributor("dataSource")).isInstanceOf(DataSourceHealthIndicator.class);
assertThat(composite.getContributor("routingDataSource"))
.isInstanceOf(RoutingDataSourceHealthContributor.class);
});
}

@Test
void runWithRoutingAndEmbeddedDataSourceShouldNotIncludeRoutingDataSourceWhenIgnored() {
this.contextRunner.withUserConfiguration(EmbeddedDataSourceConfiguration.class, RoutingDataSourceConfig.class)
Expand All @@ -98,6 +120,19 @@ void runWithRoutingAndEmbeddedDataSourceShouldNotIncludeRoutingDataSourceWhenIgn
});
}

@Test
void runWithProxyBeanPostProcessorAndRoutingAndEmbeddedDataSourceShouldNotIncludeRoutingDataSourceWhenIgnored() {
this.contextRunner
.withUserConfiguration(ProxyDataSourceBeanPostProcessor.class, EmbeddedDataSourceConfiguration.class,
RoutingDataSourceConfig.class)
.withPropertyValues("management.health.db.ignore-routing-datasources:true")
.run((context) -> {
assertThat(context).doesNotHaveBean(CompositeHealthContributor.class);
assertThat(context).hasSingleBean(DataSourceHealthIndicator.class);
assertThat(context).doesNotHaveBean(RoutingDataSourceHealthContributor.class);
});
}

@Test
void runWithOnlyRoutingDataSourceShouldIncludeRoutingDataSourceWithComposedIndicators() {
this.contextRunner.withUserConfiguration(RoutingDataSourceConfig.class).run((context) -> {
Expand All @@ -112,6 +147,23 @@ void runWithOnlyRoutingDataSourceShouldIncludeRoutingDataSourceWithComposedIndic
});
}

@Test
void runWithProxyBeanPostProcessorAndRoutingDataSourceShouldIncludeRoutingDataSourceWithComposedIndicators() {
this.contextRunner.withUserConfiguration(ProxyDataSourceBeanPostProcessor.class, RoutingDataSourceConfig.class)
.run((context) -> {
assertThat(context).hasSingleBean(RoutingDataSourceHealthContributor.class);
RoutingDataSourceHealthContributor routingHealthContributor = context
.getBean(RoutingDataSourceHealthContributor.class);
assertThat(routingHealthContributor.getContributor("one"))
.isInstanceOf(DataSourceHealthIndicator.class);
assertThat(routingHealthContributor.getContributor("two"))
.isInstanceOf(DataSourceHealthIndicator.class);
assertThat(routingHealthContributor.iterator()).toIterable()
.extracting("name")
.containsExactlyInAnyOrder("one", "two");
});
}

@Test
void runWithOnlyRoutingDataSourceShouldCrashWhenIgnored() {
this.contextRunner.withUserConfiguration(RoutingDataSourceConfig.class)
Expand All @@ -121,6 +173,15 @@ void runWithOnlyRoutingDataSourceShouldCrashWhenIgnored() {
.hasRootCauseInstanceOf(IllegalArgumentException.class));
}

@Test
void runWithProxyBeanPostProcessorAndOnlyRoutingDataSourceShouldCrashWhenIgnored() {
this.contextRunner.withUserConfiguration(ProxyDataSourceBeanPostProcessor.class, RoutingDataSourceConfig.class)
.withPropertyValues("management.health.db.ignore-routing-datasources:true")
.run((context) -> assertThat(context).hasFailed()
.getFailure()
.hasRootCauseInstanceOf(IllegalArgumentException.class));
}

@Test
void runWithValidationQueryPropertyShouldUseCustomQuery() {
this.contextRunner
Expand Down Expand Up @@ -177,30 +238,112 @@ DataSource testDataSource() {
static class RoutingDataSourceConfig {

@Bean
AbstractRoutingDataSource routingDataSource() {
AbstractRoutingDataSource routingDataSource() throws SQLException {
Map<Object, DataSource> dataSources = new HashMap<>();
dataSources.put("one", mock(DataSource.class));
dataSources.put("two", mock(DataSource.class));
AbstractRoutingDataSource routingDataSource = mock(AbstractRoutingDataSource.class);
given(routingDataSource.isWrapperFor(AbstractRoutingDataSource.class)).willReturn(true);
given(routingDataSource.unwrap(AbstractRoutingDataSource.class)).willReturn(routingDataSource);
given(routingDataSource.getResolvedDataSources()).willReturn(dataSources);
return routingDataSource;
}

}

static class ProxyDataSourceBeanPostProcessor implements BeanPostProcessor {

@Override
public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
if (bean instanceof DataSource dataSource) {
return new ProxyDataSource(dataSource);
}
return bean;
}

}

@Configuration(proxyBeanMethods = false)
static class NullKeyRoutingDataSourceConfig {

@Bean
AbstractRoutingDataSource routingDataSource() {
AbstractRoutingDataSource routingDataSource() throws Exception {
Map<Object, DataSource> dataSources = new HashMap<>();
dataSources.put(null, mock(DataSource.class));
dataSources.put("one", mock(DataSource.class));
AbstractRoutingDataSource routingDataSource = mock(AbstractRoutingDataSource.class);
given(routingDataSource.isWrapperFor(AbstractRoutingDataSource.class)).willReturn(true);
given(routingDataSource.unwrap(AbstractRoutingDataSource.class)).willReturn(routingDataSource);
given(routingDataSource.getResolvedDataSources()).willReturn(dataSources);
return routingDataSource;
}

}

static class ProxyDataSource implements DataSource {

private final DataSource dataSource;

ProxyDataSource(DataSource dataSource) {
this.dataSource = dataSource;
}

@Override
public void setLogWriter(PrintWriter out) throws SQLException {
this.dataSource.setLogWriter(out);
}

@Override
public Connection getConnection() throws SQLException {
return this.dataSource.getConnection();
}

@Override
public Connection getConnection(String username, String password) throws SQLException {
return this.dataSource.getConnection(username, password);
}

@Override
public PrintWriter getLogWriter() throws SQLException {
return this.dataSource.getLogWriter();
}

@Override
public void setLoginTimeout(int seconds) throws SQLException {
this.dataSource.setLoginTimeout(seconds);
}

@Override
public int getLoginTimeout() throws SQLException {
return this.dataSource.getLoginTimeout();
}

@Override
public ConnectionBuilder createConnectionBuilder() throws SQLException {
return this.dataSource.createConnectionBuilder();
}

@Override
public Logger getParentLogger() throws SQLFeatureNotSupportedException {
return this.dataSource.getParentLogger();
}

@Override
public ShardingKeyBuilder createShardingKeyBuilder() throws SQLException {
return this.dataSource.createShardingKeyBuilder();
}

@Override
@SuppressWarnings("unchecked")
public <T> T unwrap(Class<T> iface) throws SQLException {
return iface.isInstance(this) ? (T) this : this.dataSource.unwrap(iface);
}

@Override
public boolean isWrapperFor(Class<?> iface) throws SQLException {
return (iface.isInstance(this) || this.dataSource.isWrapperFor(iface));
}

}

}

0 comments on commit 3f9f049

Please sign in to comment.