Skip to content

Commit

Permalink
Polish "Use DataSource.unwrap to get routing data source"
Browse files Browse the repository at this point in the history
  • Loading branch information
snicoll committed Sep 16, 2024
1 parent 3f9f049 commit 78a140a
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 81 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ public HealthContributor dbHealthContributor(Map<String, DataSource> dataSources
if (dataSourceHealthIndicatorProperties.isIgnoreRoutingDataSources()) {
Map<String, DataSource> filteredDatasources = dataSources.entrySet()
.stream()
.filter((e) -> !isAbstractRoutingDataSource(e.getValue()))
.filter((e) -> !isRoutingDataSource(e.getValue()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
return createContributor(filteredDatasources);
}
Expand All @@ -105,9 +105,8 @@ private HealthContributor createContributor(Map<String, DataSource> beans) {
}

private HealthContributor createContributor(DataSource source) {
if (isAbstractRoutingDataSource(source)) {
return new RoutingDataSourceHealthContributor(unwrapAbstractRoutingDataSource(source),
this::createContributor);
if (isRoutingDataSource(source)) {
return new RoutingDataSourceHealthContributor(extractRoutingDataSource(source), this::createContributor);
}
return new DataSourceHealthIndicator(source, getValidationQuery(source));
}
Expand All @@ -117,7 +116,7 @@ private String getValidationQuery(DataSource source) {
return (poolMetadata != null) ? poolMetadata.getValidationQuery() : null;
}

private static boolean isAbstractRoutingDataSource(DataSource dataSource) {
private static boolean isRoutingDataSource(DataSource dataSource) {
if (dataSource instanceof AbstractRoutingDataSource) {
return true;
}
Expand All @@ -129,16 +128,15 @@ private static boolean isAbstractRoutingDataSource(DataSource dataSource) {
}
}

private static AbstractRoutingDataSource unwrapAbstractRoutingDataSource(DataSource dataSource) {
private static AbstractRoutingDataSource extractRoutingDataSource(DataSource dataSource) {
if (dataSource instanceof AbstractRoutingDataSource routingDataSource) {
return routingDataSource;
}
try {
return dataSource.unwrap(AbstractRoutingDataSource.class);
}
catch (SQLException ex) {
throw new IllegalStateException(
"DataSource '%s' failed to unwrap '%s'".formatted(dataSource, AbstractRoutingDataSource.class), ex);
throw new IllegalStateException("Failed to unwrap AbstractRoutingDataSource from " + dataSource, ex);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,9 @@

package org.springframework.boot.actuate.autoconfigure.jdbc;

import java.io.PrintWriter;
import java.sql.Connection;
import java.sql.ConnectionBuilder;
import java.sql.SQLException;
import java.sql.SQLFeatureNotSupportedException;
import java.sql.ShardingKeyBuilder;
import java.util.HashMap;
import java.util.Map;
import java.util.logging.Logger;

import javax.sql.DataSource;

Expand Down Expand Up @@ -256,11 +250,24 @@ static class ProxyDataSourceBeanPostProcessor implements BeanPostProcessor {
@Override
public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
if (bean instanceof DataSource dataSource) {
return new ProxyDataSource(dataSource);
return proxyDataSource(dataSource);
}
return bean;
}

private static DataSource proxyDataSource(DataSource dataSource) {
try {
DataSource mock = mock(DataSource.class);
given(mock.isWrapperFor(AbstractRoutingDataSource.class))
.willReturn(dataSource instanceof AbstractRoutingDataSource);
given(mock.unwrap(AbstractRoutingDataSource.class)).willAnswer((invocation) -> dataSource);
return mock;
}
catch (SQLException ex) {
throw new IllegalStateException(ex);
}
}

}

@Configuration(proxyBeanMethods = false)
Expand All @@ -280,70 +287,4 @@ AbstractRoutingDataSource routingDataSource() throws Exception {

}

static class ProxyDataSource implements DataSource {

private final DataSource dataSource;

ProxyDataSource(DataSource dataSource) {
this.dataSource = dataSource;
}

@Override
public void setLogWriter(PrintWriter out) throws SQLException {
this.dataSource.setLogWriter(out);
}

@Override
public Connection getConnection() throws SQLException {
return this.dataSource.getConnection();
}

@Override
public Connection getConnection(String username, String password) throws SQLException {
return this.dataSource.getConnection(username, password);
}

@Override
public PrintWriter getLogWriter() throws SQLException {
return this.dataSource.getLogWriter();
}

@Override
public void setLoginTimeout(int seconds) throws SQLException {
this.dataSource.setLoginTimeout(seconds);
}

@Override
public int getLoginTimeout() throws SQLException {
return this.dataSource.getLoginTimeout();
}

@Override
public ConnectionBuilder createConnectionBuilder() throws SQLException {
return this.dataSource.createConnectionBuilder();
}

@Override
public Logger getParentLogger() throws SQLFeatureNotSupportedException {
return this.dataSource.getParentLogger();
}

@Override
public ShardingKeyBuilder createShardingKeyBuilder() throws SQLException {
return this.dataSource.createShardingKeyBuilder();
}

@Override
@SuppressWarnings("unchecked")
public <T> T unwrap(Class<T> iface) throws SQLException {
return iface.isInstance(this) ? (T) this : this.dataSource.unwrap(iface);
}

@Override
public boolean isWrapperFor(Class<?> iface) throws SQLException {
return (iface.isInstance(this) || this.dataSource.isWrapperFor(iface));
}

}

}

0 comments on commit 78a140a

Please sign in to comment.