Skip to content

Commit

Permalink
Merge pull request #12 from eeichinger/concurrency_bug
Browse files Browse the repository at this point in the history
#11 fix concurrent modification problem in MockConnection
  • Loading branch information
eeichinger authored Jul 25, 2016
2 parents 376754b + 9847176 commit d5cf8e4
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import com.mockrunner.jdbc.CallableStatementResultSetHandler;
import com.mockrunner.jdbc.PreparedStatementResultSetHandler;
import com.mockrunner.jdbc.StatementResultSetHandler;
import com.mockrunner.mock.jdbc.JDBCMockObjectFactory;
import com.mockrunner.mock.jdbc.MockConnection;
import com.mockrunner.mock.jdbc.MockDataSource;
import com.mockrunner.mock.jdbc.MockStatement;
import com.p6spy.engine.common.ConnectionInformation;
import com.p6spy.engine.logging.P6LogOptions;
import com.p6spy.engine.proxy.Delegate;
Expand All @@ -25,12 +26,14 @@

import javax.sql.DataSource;

import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.*;
import java.util.function.Consumer;
import java.util.stream.Stream;

/**
Expand All @@ -56,46 +59,79 @@ public DataSource spyOnDataSource(DataSource ds) {
}

public DataSource createMockDataSource() {
JDBCMockObjectFactory jdbcMockObjectFactory = new JDBCMockObjectFactory() {
@Override
public void registerMockDriver() {
// we don't want to auto-hijack DriverManager
}

@Override
public void restoreDrivers() {
// we don't want to auto-hijack DriverManager
}
MockDataSource dataSource = new MockDataSource() {{
setupConnection(new StubbingMockConnection());
}};
return interceptDataSource(dataSource);
}

@Override
public MockConnection createMockConnection() {
// this is a hack, leveraging the fact that getSQLException() is the first
// method checked to determine the mock resultset behaviour.
return new MockConnection(
new StatementResultSetHandler() {
@Override
public SQLException getSQLException(String sql) {
throw new AssertionError("unmatched sql statement: '" + sql + "'");
}
}
, new PreparedStatementResultSetHandler() {
private static class StubbingMockConnection extends MockConnection {
public StubbingMockConnection() {
this(new SynchronizedStatementResultSetHandler()
, new PreparedStatementResultSetHandler() {
@Override
public SQLException getSQLException(String sql) {
throw new AssertionError("unmatched sql statement: '" + sql + "'");
}
}
, new CallableStatementResultSetHandler() {
, new CallableStatementResultSetHandler() {
@Override
public SQLException getSQLException(String sql) {
throw new AssertionError("unmatched sql statement: '" + sql + "'");
}
}
);
);
}

public StubbingMockConnection(StatementResultSetHandler statementHandler, PreparedStatementResultSetHandler preparedStatementHandler, CallableStatementResultSetHandler callableStatementHandler) {
super(synchronizeMembers(statementHandler), synchronizeMembers(preparedStatementHandler), synchronizeMembers(callableStatementHandler));
}

@SneakyThrows
private static <T> T synchronizeMembers(T o) {
doWithFields(o.getClass(), f->syncField(o, f));
for(Field f : o.getClass().getDeclaredFields()) {
syncField(o, f);
}
};
final MockConnection mockConnection = jdbcMockObjectFactory.getMockConnection();
jdbcMockObjectFactory.getMockDataSource().setupConnection(mockConnection);
return interceptDataSource(jdbcMockObjectFactory.getMockDataSource());
return o;
}

private static void doWithFields(Class<?> clazz, Consumer<Field> fc) {
// Keep backing up the inheritance hierarchy.
Class<?> targetClass = clazz;
do {
Field[] fields = targetClass.getDeclaredFields();
for (Field field : fields) {
field.setAccessible(true);
fc.accept(field);
}
targetClass = targetClass.getSuperclass();
}
while (targetClass != null && targetClass != Object.class);
}

@SneakyThrows
private static <T> void syncField(T o, Field f) {
Class<?> fieldType = f.getType();
Object value = f.get(o);
if (List.class.isAssignableFrom(fieldType) && value != null) {
f.set(o, Collections.synchronizedList((List<?>) value));
} else if (Map.class.isAssignableFrom(fieldType) && value != null) {
f.set(o, Collections.synchronizedMap((Map<?,?>)value));
}
}

private static class SynchronizedStatementResultSetHandler extends StatementResultSetHandler {
@Override
public SQLException getSQLException(String sql) {
throw new AssertionError("unmatched sql statement: '" + sql + "'");
}

@Override
public synchronized void addStatement(MockStatement statement) {
super.addStatement(statement);
}
}
}

@Override
Expand Down Expand Up @@ -142,7 +178,7 @@ protected Object interceptPreparedStatementExecution(PreparedStatementInformatio
try (CloseableHttpResponse response = httpclient.execute(httpPost)) {
if (response.getStatusLine().getStatusCode() == 200) {
String responseContent = EntityUtils.toString(response.getEntity(), "utf-8");
if(int[].class.equals(method.getReturnType())) {
if (int[].class.equals(method.getReturnType())) {
return parseBatchUpdateRowsAffected(responseContent);
}
if (int.class.equals(method.getReturnType())) {
Expand Down Expand Up @@ -203,17 +239,20 @@ protected Delegate createDataSourceGetConnectionDelegate() {

protected Delegate createConnectionPrepareStatementDelegate(final ConnectionInformation connectionInformation) {
return (final Object proxy, final Object underlying, final Method method, final Object[] args) -> {
PreparedStatement statement = (PreparedStatement) method.invoke(underlying, args);
String query = (String) args[0];
GenericInvocationHandler<PreparedStatement> invocationHandler = createPreparedStatementInvocationHandler(connectionInformation, statement, query);
return ProxyFactory.createProxy(statement, invocationHandler);
synchronized (this) {
PreparedStatement statement = (PreparedStatement) method.invoke(underlying, args);
String query = (String) args[0];
GenericInvocationHandler<PreparedStatement> invocationHandler = createPreparedStatementInvocationHandler(connectionInformation, statement, query);
return ProxyFactory.createProxy(statement, invocationHandler);
}
};
}

protected Delegate createPreparedStatementExecuteDelegate(final PreparedStatementInformation preparedStatementInformation) {
return (final Object proxy, final Object underlying, final Method method, final Object[] args) -> {

return interceptPreparedStatementExecution(preparedStatementInformation, underlying, method, args);
synchronized (preparedStatementInformation.getConnectionInformation()) {
return interceptPreparedStatementExecution(preparedStatementInformation, underlying, method, args);
}
};
}

Expand Down Expand Up @@ -327,7 +366,7 @@ protected P6MockPreparedStatementInvocationHandler createPreparedStatementInvoca
* @return array with corresponding number of updated rows for each batch
*/
private static int[] parseBatchUpdateRowsAffected(String responseContent) {
return Stream.of(responseContent.split(",")).mapToInt(s->Integer.parseInt(s)).toArray();
return Stream.of(responseContent.split(",")).mapToInt(s -> Integer.parseInt(s)).toArray();
}

}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package org.eeichinger.servicevirtualisation.jdbc;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import javax.sql.DataSource;

import com.github.tomakehurst.wiremock.client.WireMock;
import com.github.tomakehurst.wiremock.junit.WireMockRule;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.springframework.jdbc.core.JdbcTemplate;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;

/**
* @author Erich Eichinger
* @since 22/07/16
*/
public class MockDataSourceThreadSafetyTest {

@Rule
public WireMockRule wireMockRule = new WireMockRule(0);

@Rule
public ExpectedException thrown = ExpectedException.none();

DataSource dataSource;

@Before
public void before() {
JdbcServiceVirtualizationFactory myP6MockFactory = new JdbcServiceVirtualizationFactory();
myP6MockFactory.setTargetUrl("http://localhost:" + wireMockRule.port() + "/sqlstub");

dataSource = myP6MockFactory.createMockDataSource();
}

@Test
public void intercepts_matching_preparedstatement_and_responds_with_mockresultset() throws Throwable {
final Connection connection = dataSource.getConnection();

final int count = 20000;
final ExecutorService executor = Executors.newFixedThreadPool(1000);

Random random = new Random(System.currentTimeMillis());

ArrayList<Callable<Throwable>> tasks = new ArrayList<>();
for(int i=0; i<count;i++) {
final String sql = "SELECT birthday FROM PEOPLE WHERE name = ? AND " + i + "=" + i;
tasks.add(() -> {
try {
Thread.sleep(random.nextInt(50));
final PreparedStatement ps = connection.prepareStatement(sql);
return null;
} catch (Throwable e) {
return e;
}
});
}

final List<Future<Throwable>> results = executor.invokeAll(tasks);
for(int i=0;i<count;i++) {
final Throwable caughtException = results.get(i).get();
if (caughtException != null) {
throw caughtException;
}
}
}
}
3 changes: 2 additions & 1 deletion src/test/resources/logback-test.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
</encoder>
</appender>
<!--<logger level="DEBUG" name="org.apache.http.wire"/>-->
<root level="DEBUG">
<logger level="DEBUG" name="org.eeichinger"/>
<root level="INFO">
<appender-ref ref="STDOUT"/>
</root>
</configuration>

0 comments on commit d5cf8e4

Please sign in to comment.