diff --git a/sharding-core/src/main/java/io/shardingjdbc/core/parsing/SQLParsingEngine.java b/sharding-core/src/main/java/io/shardingjdbc/core/parsing/SQLParsingEngine.java index 5075a9849895d..4e3eb69aaeff0 100755 --- a/sharding-core/src/main/java/io/shardingjdbc/core/parsing/SQLParsingEngine.java +++ b/sharding-core/src/main/java/io/shardingjdbc/core/parsing/SQLParsingEngine.java @@ -23,6 +23,8 @@ import io.shardingjdbc.core.parsing.lexer.LexerEngineFactory; import io.shardingjdbc.core.parsing.parser.sql.SQLParserFactory; import io.shardingjdbc.core.parsing.parser.sql.SQLStatement; +import io.shardingjdbc.core.parsing.parser.token.GeneratedKeyToken; +import io.shardingjdbc.core.parsing.parser.token.SQLToken; import io.shardingjdbc.core.rule.ShardingRule; import lombok.RequiredArgsConstructor; @@ -52,6 +54,20 @@ public SQLStatement parse() { } LexerEngine lexerEngine = LexerEngineFactory.newInstance(dbType, sql); lexerEngine.nextToken(); - return SQLParserFactory.newInstance(dbType, lexerEngine.getCurrentToken().getType(), shardingRule, lexerEngine).parse(); + SQLStatement result = SQLParserFactory.newInstance(dbType, lexerEngine.getCurrentToken().getType(), shardingRule, lexerEngine).parse(); + // TODO cannot cache InsertStatement here by generate key, should not modify original InsertStatement on router. + if (!findGeneratedKeyToken(result)) { + ParsingResultCache.getInstance().put(sql, result); + } + return result; + } + + private boolean findGeneratedKeyToken(final SQLStatement sqlStatement) { + for (SQLToken each : sqlStatement.getSqlTokens()) { + if (each instanceof GeneratedKeyToken) { + return true; + } + } + return false; } } diff --git a/sharding-core/src/main/java/io/shardingjdbc/core/parsing/cache/ParsingResultCache.java b/sharding-core/src/main/java/io/shardingjdbc/core/parsing/cache/ParsingResultCache.java index 9bc28925752f4..430c554ab1ce5 100644 --- a/sharding-core/src/main/java/io/shardingjdbc/core/parsing/cache/ParsingResultCache.java +++ b/sharding-core/src/main/java/io/shardingjdbc/core/parsing/cache/ParsingResultCache.java @@ -64,4 +64,11 @@ public void put(final String sql, final SQLStatement sqlStatement) { public SQLStatement getSQLStatement(final String sql) { return cache.get(sql); } + + /** + * Clear cache. + */ + public synchronized void clear() { + cache.clear(); + } } diff --git a/sharding-jdbc-core/src/test/java/io/shardingjdbc/core/common/base/AbstractSQLAssertTest.java b/sharding-jdbc-core/src/test/java/io/shardingjdbc/core/common/base/AbstractSQLAssertTest.java index 92350a8b89c06..21eb9e27fb583 100644 --- a/sharding-jdbc-core/src/test/java/io/shardingjdbc/core/common/base/AbstractSQLAssertTest.java +++ b/sharding-jdbc-core/src/test/java/io/shardingjdbc/core/common/base/AbstractSQLAssertTest.java @@ -22,7 +22,6 @@ import io.shardingjdbc.core.common.env.DatabaseEnvironment; import io.shardingjdbc.core.common.env.ShardingJdbcDatabaseTester; import io.shardingjdbc.core.common.env.ShardingTestStrategy; -import io.shardingjdbc.core.util.SQLAssertHelper; import io.shardingjdbc.core.constant.DatabaseType; import io.shardingjdbc.core.constant.SQLType; import io.shardingjdbc.core.integrate.jaxb.SQLAssertData; @@ -30,6 +29,8 @@ import io.shardingjdbc.core.jdbc.adapter.AbstractDataSourceAdapter; import io.shardingjdbc.core.jdbc.core.datasource.MasterSlaveDataSource; import io.shardingjdbc.core.jdbc.core.datasource.ShardingDataSource; +import io.shardingjdbc.core.parsing.cache.ParsingResultCache; +import io.shardingjdbc.core.util.SQLAssertHelper; import lombok.Getter; import org.dbunit.DatabaseUnitException; import org.dbunit.IDatabaseTester; @@ -135,6 +136,11 @@ public void cleanupDDLTables() throws SQLException { } } + @After + public void cleanupParsingResultCache() { + ParsingResultCache.getInstance().clear(); + } + private void executeSQL(final String sql) throws SQLException { for (Map.Entry each : getDataSources().entrySet()) { if (getCurrentDatabaseType() == each.getKey()) { diff --git a/sharding-jdbc-core/src/test/java/io/shardingjdbc/core/jdbc/core/statement/ShardingStatementTest.java b/sharding-jdbc-core/src/test/java/io/shardingjdbc/core/jdbc/core/statement/ShardingStatementTest.java index 7a2f70e1ffb6f..f8a545dbfaf41 100644 --- a/sharding-jdbc-core/src/test/java/io/shardingjdbc/core/jdbc/core/statement/ShardingStatementTest.java +++ b/sharding-jdbc-core/src/test/java/io/shardingjdbc/core/jdbc/core/statement/ShardingStatementTest.java @@ -32,17 +32,17 @@ import static org.junit.Assert.assertTrue; public final class ShardingStatementTest extends AbstractShardingJDBCDatabaseAndTableTest { - + private String sql = "SELECT COUNT(*) AS orders_count FROM t_order WHERE status = 'init'"; - + private String sql2 = "DELETE FROM t_order WHERE status ='init'"; - + private String sql3 = "INSERT INTO t_order_item(order_id, user_id, status) VALUES (%d, %d, '%s')"; - + public ShardingStatementTest(final DatabaseType databaseType) { super(databaseType); } - + @Test public void assertExecuteQuery() throws SQLException { try ( @@ -53,7 +53,7 @@ public void assertExecuteQuery() throws SQLException { assertThat(resultSet.getLong(1), is(4L)); } } - + @Test public void assertExecuteUpdate() throws SQLException { try ( @@ -62,7 +62,7 @@ public void assertExecuteUpdate() throws SQLException { assertThat(stmt.executeUpdate(sql2), is(4)); } } - + @Test public void assertExecute() throws SQLException { try ( @@ -73,7 +73,7 @@ public void assertExecute() throws SQLException { assertThat(stmt.getResultSet().getLong(1), is(4L)); } } - + @Test public void assertExecuteQueryWithResultSetTypeAndResultSetConcurrency() throws SQLException { try ( @@ -84,7 +84,7 @@ public void assertExecuteQueryWithResultSetTypeAndResultSetConcurrency() throws assertThat(resultSet.getLong(1), is(4L)); } } - + @Test public void assertExecuteQueryWithResultSetTypeAndResultSetConcurrencyAndResultSetHoldability() throws SQLException { try ( @@ -95,7 +95,7 @@ public void assertExecuteQueryWithResultSetTypeAndResultSetConcurrencyAndResultS assertThat(resultSet.getLong(1), is(4L)); } } - + @Test public void assertExecuteUpdateWithAutoGeneratedKeys() throws SQLException { try ( @@ -104,7 +104,7 @@ public void assertExecuteUpdateWithAutoGeneratedKeys() throws SQLException { assertThat(stmt.executeUpdate(sql2, Statement.NO_GENERATED_KEYS), is(4)); } } - + @Test public void assertExecuteUpdateWithColumnIndexes() throws SQLException { if (DatabaseType.PostgreSQL != getCurrentDatabaseType()) { @@ -115,7 +115,7 @@ public void assertExecuteUpdateWithColumnIndexes() throws SQLException { } } } - + @Test public void assertExecuteUpdateWithColumnNames() throws SQLException { if (DatabaseType.H2 == getCurrentDatabaseType() || DatabaseType.MySQL == getCurrentDatabaseType()) { @@ -126,7 +126,7 @@ public void assertExecuteUpdateWithColumnNames() throws SQLException { } } } - + @Test public void assertExecuteWithAutoGeneratedKeys() throws SQLException { try ( @@ -137,7 +137,7 @@ public void assertExecuteWithAutoGeneratedKeys() throws SQLException { assertThat(stmt.getResultSet().getLong(1), is(4L)); } } - + @Test public void assertExecuteWithColumnIndexes() throws SQLException { if (DatabaseType.PostgreSQL != getCurrentDatabaseType()) { @@ -150,7 +150,7 @@ public void assertExecuteWithColumnIndexes() throws SQLException { } } } - + @Test public void assertExecuteWithColumnNames() throws SQLException { if (DatabaseType.PostgreSQL != getCurrentDatabaseType()) { @@ -163,7 +163,7 @@ public void assertExecuteWithColumnNames() throws SQLException { } } } - + @Test public void assertGetConnection() throws SQLException { try ( @@ -172,7 +172,7 @@ public void assertGetConnection() throws SQLException { assertThat(stmt.getConnection(), is(connection)); } } - + @Test public void assertGetGeneratedKeys() throws SQLException { if (DatabaseType.PostgreSQL != getCurrentDatabaseType()) {