Skip to content

Commit

Permalink
for #701, put sql in cache
Browse files Browse the repository at this point in the history
  • Loading branch information
terrymanu committed Apr 8, 2018
1 parent f8dedd5 commit d6bb344
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import io.shardingjdbc.core.parsing.lexer.LexerEngineFactory;
import io.shardingjdbc.core.parsing.parser.sql.SQLParserFactory;
import io.shardingjdbc.core.parsing.parser.sql.SQLStatement;
import io.shardingjdbc.core.parsing.parser.token.GeneratedKeyToken;
import io.shardingjdbc.core.parsing.parser.token.SQLToken;
import io.shardingjdbc.core.rule.ShardingRule;
import lombok.RequiredArgsConstructor;

Expand Down Expand Up @@ -52,6 +54,20 @@ public SQLStatement parse() {
}
LexerEngine lexerEngine = LexerEngineFactory.newInstance(dbType, sql);
lexerEngine.nextToken();
return SQLParserFactory.newInstance(dbType, lexerEngine.getCurrentToken().getType(), shardingRule, lexerEngine).parse();
SQLStatement result = SQLParserFactory.newInstance(dbType, lexerEngine.getCurrentToken().getType(), shardingRule, lexerEngine).parse();
// TODO cannot cache InsertStatement here by generate key, should not modify original InsertStatement on router.
if (!findGeneratedKeyToken(result)) {
ParsingResultCache.getInstance().put(sql, result);
}
return result;
}

private boolean findGeneratedKeyToken(final SQLStatement sqlStatement) {
for (SQLToken each : sqlStatement.getSqlTokens()) {
if (each instanceof GeneratedKeyToken) {
return true;
}
}
return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,11 @@ public void put(final String sql, final SQLStatement sqlStatement) {
public SQLStatement getSQLStatement(final String sql) {
return cache.get(sql);
}

/**
* Clear cache.
*/
public synchronized void clear() {
cache.clear();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,15 @@
import io.shardingjdbc.core.common.env.DatabaseEnvironment;
import io.shardingjdbc.core.common.env.ShardingJdbcDatabaseTester;
import io.shardingjdbc.core.common.env.ShardingTestStrategy;
import io.shardingjdbc.core.util.SQLAssertHelper;
import io.shardingjdbc.core.constant.DatabaseType;
import io.shardingjdbc.core.constant.SQLType;
import io.shardingjdbc.core.integrate.jaxb.SQLAssertData;
import io.shardingjdbc.core.integrate.jaxb.SQLShardingRule;
import io.shardingjdbc.core.jdbc.adapter.AbstractDataSourceAdapter;
import io.shardingjdbc.core.jdbc.core.datasource.MasterSlaveDataSource;
import io.shardingjdbc.core.jdbc.core.datasource.ShardingDataSource;
import io.shardingjdbc.core.parsing.cache.ParsingResultCache;
import io.shardingjdbc.core.util.SQLAssertHelper;
import lombok.Getter;
import org.dbunit.DatabaseUnitException;
import org.dbunit.IDatabaseTester;
Expand Down Expand Up @@ -135,6 +136,11 @@ public void cleanupDDLTables() throws SQLException {
}
}

@After
public void cleanupParsingResultCache() {
ParsingResultCache.getInstance().clear();
}

private void executeSQL(final String sql) throws SQLException {
for (Map.Entry<DatabaseType, ? extends AbstractDataSourceAdapter> each : getDataSources().entrySet()) {
if (getCurrentDatabaseType() == each.getKey()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,17 @@
import static org.junit.Assert.assertTrue;

public final class ShardingStatementTest extends AbstractShardingJDBCDatabaseAndTableTest {

private String sql = "SELECT COUNT(*) AS orders_count FROM t_order WHERE status = 'init'";

private String sql2 = "DELETE FROM t_order WHERE status ='init'";

private String sql3 = "INSERT INTO t_order_item(order_id, user_id, status) VALUES (%d, %d, '%s')";

public ShardingStatementTest(final DatabaseType databaseType) {
super(databaseType);
}

@Test
public void assertExecuteQuery() throws SQLException {
try (
Expand All @@ -53,7 +53,7 @@ public void assertExecuteQuery() throws SQLException {
assertThat(resultSet.getLong(1), is(4L));
}
}

@Test
public void assertExecuteUpdate() throws SQLException {
try (
Expand All @@ -62,7 +62,7 @@ public void assertExecuteUpdate() throws SQLException {
assertThat(stmt.executeUpdate(sql2), is(4));
}
}

@Test
public void assertExecute() throws SQLException {
try (
Expand All @@ -73,7 +73,7 @@ public void assertExecute() throws SQLException {
assertThat(stmt.getResultSet().getLong(1), is(4L));
}
}

@Test
public void assertExecuteQueryWithResultSetTypeAndResultSetConcurrency() throws SQLException {
try (
Expand All @@ -84,7 +84,7 @@ public void assertExecuteQueryWithResultSetTypeAndResultSetConcurrency() throws
assertThat(resultSet.getLong(1), is(4L));
}
}

@Test
public void assertExecuteQueryWithResultSetTypeAndResultSetConcurrencyAndResultSetHoldability() throws SQLException {
try (
Expand All @@ -95,7 +95,7 @@ public void assertExecuteQueryWithResultSetTypeAndResultSetConcurrencyAndResultS
assertThat(resultSet.getLong(1), is(4L));
}
}

@Test
public void assertExecuteUpdateWithAutoGeneratedKeys() throws SQLException {
try (
Expand All @@ -104,7 +104,7 @@ public void assertExecuteUpdateWithAutoGeneratedKeys() throws SQLException {
assertThat(stmt.executeUpdate(sql2, Statement.NO_GENERATED_KEYS), is(4));
}
}

@Test
public void assertExecuteUpdateWithColumnIndexes() throws SQLException {
if (DatabaseType.PostgreSQL != getCurrentDatabaseType()) {
Expand All @@ -115,7 +115,7 @@ public void assertExecuteUpdateWithColumnIndexes() throws SQLException {
}
}
}

@Test
public void assertExecuteUpdateWithColumnNames() throws SQLException {
if (DatabaseType.H2 == getCurrentDatabaseType() || DatabaseType.MySQL == getCurrentDatabaseType()) {
Expand All @@ -126,7 +126,7 @@ public void assertExecuteUpdateWithColumnNames() throws SQLException {
}
}
}

@Test
public void assertExecuteWithAutoGeneratedKeys() throws SQLException {
try (
Expand All @@ -137,7 +137,7 @@ public void assertExecuteWithAutoGeneratedKeys() throws SQLException {
assertThat(stmt.getResultSet().getLong(1), is(4L));
}
}

@Test
public void assertExecuteWithColumnIndexes() throws SQLException {
if (DatabaseType.PostgreSQL != getCurrentDatabaseType()) {
Expand All @@ -150,7 +150,7 @@ public void assertExecuteWithColumnIndexes() throws SQLException {
}
}
}

@Test
public void assertExecuteWithColumnNames() throws SQLException {
if (DatabaseType.PostgreSQL != getCurrentDatabaseType()) {
Expand All @@ -163,7 +163,7 @@ public void assertExecuteWithColumnNames() throws SQLException {
}
}
}

@Test
public void assertGetConnection() throws SQLException {
try (
Expand All @@ -172,7 +172,7 @@ public void assertGetConnection() throws SQLException {
assertThat(stmt.getConnection(), is(connection));
}
}

@Test
public void assertGetGeneratedKeys() throws SQLException {
if (DatabaseType.PostgreSQL != getCurrentDatabaseType()) {
Expand Down

0 comments on commit d6bb344

Please sign in to comment.