Skip to content

Commit

Permalink
Used DataSource.unwrap(...) method alongside with 'instance of' for d…
Browse files Browse the repository at this point in the history
…etermining AbstractRoutingDataSource
  • Loading branch information
nosan committed Sep 15, 2024
1 parent d3a2bf4 commit d5aa1a9
Show file tree
Hide file tree
Showing 2 changed files with 177 additions and 7 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2012-2023 the original author or authors.
* Copyright 2012-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,6 +16,7 @@

package org.springframework.boot.actuate.autoconfigure.jdbc;

import java.sql.SQLException;
import java.util.Collection;
import java.util.Iterator;
import java.util.Map;
Expand Down Expand Up @@ -88,7 +89,7 @@ public HealthContributor dbHealthContributor(Map<String, DataSource> dataSources
if (dataSourceHealthIndicatorProperties.isIgnoreRoutingDataSources()) {
Map<String, DataSource> filteredDatasources = dataSources.entrySet()
.stream()
.filter((e) -> !(e.getValue() instanceof AbstractRoutingDataSource))
.filter((e) -> !isAbstractRoutingDataSource(e.getValue()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
return createContributor(filteredDatasources);
}
Expand All @@ -104,8 +105,9 @@ private HealthContributor createContributor(Map<String, DataSource> beans) {
}

private HealthContributor createContributor(DataSource source) {
if (source instanceof AbstractRoutingDataSource routingDataSource) {
return new RoutingDataSourceHealthContributor(routingDataSource, this::createContributor);
if (isAbstractRoutingDataSource(source)) {
return new RoutingDataSourceHealthContributor(unwrapAbstractRoutingDataSource(source),
this::createContributor);
}
return new DataSourceHealthIndicator(source, getValidationQuery(source));
}
Expand All @@ -115,6 +117,31 @@ private String getValidationQuery(DataSource source) {
return (poolMetadata != null) ? poolMetadata.getValidationQuery() : null;
}

private static boolean isAbstractRoutingDataSource(DataSource dataSource) {
if (dataSource instanceof AbstractRoutingDataSource) {
return true;
}
try {
return dataSource.isWrapperFor(AbstractRoutingDataSource.class);
}
catch (SQLException ex) {
return false;
}
}

private static AbstractRoutingDataSource unwrapAbstractRoutingDataSource(DataSource dataSource) {
if (dataSource instanceof AbstractRoutingDataSource routingDataSource) {
return routingDataSource;
}
try {
return dataSource.unwrap(AbstractRoutingDataSource.class);
}
catch (SQLException ex) {
throw new IllegalStateException(
"DataSource '%s' failed to unwrap '%s'".formatted(dataSource, AbstractRoutingDataSource.class), ex);
}
}

/**
* {@link CompositeHealthContributor} used for {@link AbstractRoutingDataSource} beans
* where the overall health is composed of a {@link DataSourceHealthIndicator} for
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2012-2023 the original author or authors.
* Copyright 2012-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,13 +16,22 @@

package org.springframework.boot.actuate.autoconfigure.jdbc;

import java.io.PrintWriter;
import java.sql.Connection;
import java.sql.ConnectionBuilder;
import java.sql.SQLException;
import java.sql.SQLFeatureNotSupportedException;
import java.sql.ShardingKeyBuilder;
import java.util.HashMap;
import java.util.Map;
import java.util.logging.Logger;

import javax.sql.DataSource;

import org.junit.jupiter.api.Test;

import org.springframework.beans.BeansException;
import org.springframework.beans.factory.config.BeanPostProcessor;
import org.springframework.boot.actuate.autoconfigure.health.HealthContributorAutoConfiguration;
import org.springframework.boot.actuate.autoconfigure.jdbc.DataSourceHealthContributorAutoConfiguration.RoutingDataSourceHealthContributor;
import org.springframework.boot.actuate.health.CompositeHealthContributor;
Expand Down Expand Up @@ -87,6 +96,19 @@ void runWithRoutingAndEmbeddedDataSourceShouldIncludeRoutingDataSource() {
});
}

@Test
void runWithProxyBeanPostProcessorRoutingAndEmbeddedDataSourceShouldIncludeRoutingDataSource() {
this.contextRunner
.withUserConfiguration(ProxyDataSourceBeanPostProcessor.class, EmbeddedDataSourceConfiguration.class,
RoutingDataSourceConfig.class)
.run((context) -> {
CompositeHealthContributor composite = context.getBean(CompositeHealthContributor.class);
assertThat(composite.getContributor("dataSource")).isInstanceOf(DataSourceHealthIndicator.class);
assertThat(composite.getContributor("routingDataSource"))
.isInstanceOf(RoutingDataSourceHealthContributor.class);
});
}

@Test
void runWithRoutingAndEmbeddedDataSourceShouldNotIncludeRoutingDataSourceWhenIgnored() {
this.contextRunner.withUserConfiguration(EmbeddedDataSourceConfiguration.class, RoutingDataSourceConfig.class)
Expand All @@ -98,6 +120,19 @@ void runWithRoutingAndEmbeddedDataSourceShouldNotIncludeRoutingDataSourceWhenIgn
});
}

@Test
void runWithProxyBeanPostProcessorAndRoutingAndEmbeddedDataSourceShouldNotIncludeRoutingDataSourceWhenIgnored() {
this.contextRunner
.withUserConfiguration(ProxyDataSourceBeanPostProcessor.class, EmbeddedDataSourceConfiguration.class,
RoutingDataSourceConfig.class)
.withPropertyValues("management.health.db.ignore-routing-datasources:true")
.run((context) -> {
assertThat(context).doesNotHaveBean(CompositeHealthContributor.class);
assertThat(context).hasSingleBean(DataSourceHealthIndicator.class);
assertThat(context).doesNotHaveBean(RoutingDataSourceHealthContributor.class);
});
}

@Test
void runWithOnlyRoutingDataSourceShouldIncludeRoutingDataSourceWithComposedIndicators() {
this.contextRunner.withUserConfiguration(RoutingDataSourceConfig.class).run((context) -> {
Expand All @@ -112,6 +147,23 @@ void runWithOnlyRoutingDataSourceShouldIncludeRoutingDataSourceWithComposedIndic
});
}

@Test
void runWithProxyBeanPostProcessorAndRoutingDataSourceShouldIncludeRoutingDataSourceWithComposedIndicators() {
this.contextRunner.withUserConfiguration(ProxyDataSourceBeanPostProcessor.class, RoutingDataSourceConfig.class)
.run((context) -> {
assertThat(context).hasSingleBean(RoutingDataSourceHealthContributor.class);
RoutingDataSourceHealthContributor routingHealthContributor = context
.getBean(RoutingDataSourceHealthContributor.class);
assertThat(routingHealthContributor.getContributor("one"))
.isInstanceOf(DataSourceHealthIndicator.class);
assertThat(routingHealthContributor.getContributor("two"))
.isInstanceOf(DataSourceHealthIndicator.class);
assertThat(routingHealthContributor.iterator()).toIterable()
.extracting("name")
.containsExactlyInAnyOrder("one", "two");
});
}

@Test
void runWithOnlyRoutingDataSourceShouldCrashWhenIgnored() {
this.contextRunner.withUserConfiguration(RoutingDataSourceConfig.class)
Expand All @@ -121,6 +173,15 @@ void runWithOnlyRoutingDataSourceShouldCrashWhenIgnored() {
.hasRootCauseInstanceOf(IllegalArgumentException.class));
}

@Test
void runWithProxyBeanPostProcessorAndOnlyRoutingDataSourceShouldCrashWhenIgnored() {
this.contextRunner.withUserConfiguration(ProxyDataSourceBeanPostProcessor.class, RoutingDataSourceConfig.class)
.withPropertyValues("management.health.db.ignore-routing-datasources:true")
.run((context) -> assertThat(context).hasFailed()
.getFailure()
.hasRootCauseInstanceOf(IllegalArgumentException.class));
}

@Test
void runWithValidationQueryPropertyShouldUseCustomQuery() {
this.contextRunner
Expand Down Expand Up @@ -177,30 +238,112 @@ DataSource testDataSource() {
static class RoutingDataSourceConfig {

@Bean
AbstractRoutingDataSource routingDataSource() {
AbstractRoutingDataSource routingDataSource() throws SQLException {
Map<Object, DataSource> dataSources = new HashMap<>();
dataSources.put("one", mock(DataSource.class));
dataSources.put("two", mock(DataSource.class));
AbstractRoutingDataSource routingDataSource = mock(AbstractRoutingDataSource.class);
given(routingDataSource.isWrapperFor(AbstractRoutingDataSource.class)).willReturn(true);
given(routingDataSource.unwrap(AbstractRoutingDataSource.class)).willReturn(routingDataSource);
given(routingDataSource.getResolvedDataSources()).willReturn(dataSources);
return routingDataSource;
}

}

static class ProxyDataSourceBeanPostProcessor implements BeanPostProcessor {

@Override
public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
if (bean instanceof DataSource dataSource) {
return new ProxyDataSource(dataSource);
}
return bean;
}

}

@Configuration(proxyBeanMethods = false)
static class NullKeyRoutingDataSourceConfig {

@Bean
AbstractRoutingDataSource routingDataSource() {
AbstractRoutingDataSource routingDataSource() throws Exception {
Map<Object, DataSource> dataSources = new HashMap<>();
dataSources.put(null, mock(DataSource.class));
dataSources.put("one", mock(DataSource.class));
AbstractRoutingDataSource routingDataSource = mock(AbstractRoutingDataSource.class);
given(routingDataSource.isWrapperFor(AbstractRoutingDataSource.class)).willReturn(true);
given(routingDataSource.unwrap(AbstractRoutingDataSource.class)).willReturn(routingDataSource);
given(routingDataSource.getResolvedDataSources()).willReturn(dataSources);
return routingDataSource;
}

}

static class ProxyDataSource implements DataSource {

private final DataSource dataSource;

ProxyDataSource(DataSource dataSource) {
this.dataSource = dataSource;
}

@Override
public void setLogWriter(PrintWriter out) throws SQLException {
this.dataSource.setLogWriter(out);
}

@Override
public Connection getConnection() throws SQLException {
return this.dataSource.getConnection();
}

@Override
public Connection getConnection(String username, String password) throws SQLException {
return this.dataSource.getConnection(username, password);
}

@Override
public PrintWriter getLogWriter() throws SQLException {
return this.dataSource.getLogWriter();
}

@Override
public void setLoginTimeout(int seconds) throws SQLException {
this.dataSource.setLoginTimeout(seconds);
}

@Override
public int getLoginTimeout() throws SQLException {
return this.dataSource.getLoginTimeout();
}

@Override
public ConnectionBuilder createConnectionBuilder() throws SQLException {
return this.dataSource.createConnectionBuilder();
}

@Override
public Logger getParentLogger() throws SQLFeatureNotSupportedException {
return this.dataSource.getParentLogger();
}

@Override
public ShardingKeyBuilder createShardingKeyBuilder() throws SQLException {
return this.dataSource.createShardingKeyBuilder();
}

@Override
@SuppressWarnings("unchecked")
public <T> T unwrap(Class<T> iface) throws SQLException {
return iface.isInstance(this) ? (T) this : this.dataSource.unwrap(iface);
}

@Override
public boolean isWrapperFor(Class<?> iface) throws SQLException {
return (iface.isInstance(this) || this.dataSource.isWrapperFor(iface));
}

}

}

0 comments on commit d5aa1a9

Please sign in to comment.