Skip to content

Commit

Permalink
feat(open-mysql-db): refactor mock code (#3831)
Browse files Browse the repository at this point in the history
* feat(open-mysql-db): refactor

1. remove unnecessary instance var port
2. fix cause null bug
3. remove unnecessary throws
4. fix ctx.close() sequence bug
5. config sessionTimeout and requestTimeout
6. add docs of SqlEngine

* feat(open-mysql-db): refactor

* feat(open-mysql-db): revert passsword

---------

Co-authored-by: yangwucheng <[email protected]>
  • Loading branch information
yangwucheng and yangwucheng authored Apr 12, 2024
1 parent 7f758af commit 1418350
Show file tree
Hide file tree
Showing 6 changed files with 669 additions and 608 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,21 @@ public class MySqlListener implements AutoCloseable {

public static final String VERSION = "8.0.29";
public static final String VERSION_COMMENT = "";
public static final String CHARACTER_SET_UTF8MB4 = "utf8mb4";
public static final String COLLATION_UTF8MB4_0900_AI_CI = "utf8mb4_0900_ai_ci";
public static final String SETTINGS_LOWER_CASE_TABLE_NAMES = "2";
public static final String SETTINGS_INTERACTIVE_TIMEOUT = "28800";
public static final String SETTINGS_WAIT_TIMEOUT = "28800";
private static final Pattern SETTINGS_PATTERN =
Pattern.compile("@@([\\w.]+)(?:\\sAS\\s)?(\\w+)?");
private static final Pattern USE_DB_PATTERN = Pattern.compile("(?i)use (.+)");
private final SqlEngine sqlEngine;
private final int port;
private final Channel channel;
private final io.netty.channel.EventLoopGroup parentGroup;
private final EventLoopGroup childGroup;
private final EventExecutorGroup eventExecutorGroup;

public MySqlListener(int port, int executorGroupSize, SqlEngine sqlEngine) {
this.port = port;
this.sqlEngine = sqlEngine;

parentGroup = new NioEventLoopGroup();
Expand Down Expand Up @@ -87,7 +90,7 @@ public MySqlListener(int port, int executorGroupSize, SqlEngine sqlEngine) {
.childHandler(
new ChannelInitializer<NioSocketChannel>() {
@Override
protected void initChannel(NioSocketChannel ch) throws Exception {
protected void initChannel(NioSocketChannel ch) {
System.out.println("[mysql-protocol] Initializing child channel");
final ChannelPipeline pipeline = ch.pipeline();
pipeline.addLast(new MysqlServerPacketEncoder());
Expand Down Expand Up @@ -160,14 +163,21 @@ private void handleHandshakeResponse(
Throwable cause = e.getCause();
int errorCode;
byte[] sqlState;
String errMsg =
Utils.getLocalDateTimeNow()
+ " "
+ Objects.requireNonNullElse(cause.getMessage(), e.getMessage());
if (cause instanceof IllegalAccessException) {
errorCode = 1045;
sqlState = "#28000".getBytes(StandardCharsets.US_ASCII);
String errMsg;
if (cause != null) {
errMsg =
Utils.getLocalDateTimeNow()
+ " "
+ Objects.requireNonNullElse(cause.getMessage(), e.getMessage());
if (cause instanceof IllegalAccessException) {
errorCode = 1045;
sqlState = "#28000".getBytes(StandardCharsets.US_ASCII);
} else {
errorCode = 1105;
sqlState = "#HY000".getBytes(StandardCharsets.US_ASCII);
}
} else {
errMsg = Utils.getLocalDateTimeNow() + " " + Objects.requireNonNullElse(e.getMessage(), "");
errorCode = 1105;
sqlState = "#HY000".getBytes(StandardCharsets.US_ASCII);
}
Expand Down Expand Up @@ -197,17 +207,14 @@ private void handleQuery(
+ userName
+ ", scramble411: "
+ scramble411.length);
Matcher useDbMatcher =
USE_DB_PATTERN.matcher(queryString.replaceAll("/\\*.*\\*/", "").toLowerCase().trim());
String queryStringWithoutComment =
queryString.replaceAll("/\\*.*\\*/", "").toLowerCase().trim();
Matcher useDbMatcher = USE_DB_PATTERN.matcher(queryStringWithoutComment);

if (isServerSettingsQuery(queryString)) {
sendSettingsResponse(ctx, query, remoteAddr);
} else if (queryString.replaceAll("/\\*.*\\*/", "").toLowerCase().trim().startsWith("set ")
&& !queryString
.replaceAll("/\\*.*\\*/", "")
.toLowerCase()
.trim()
.startsWith("set @@execute_mode=")) {
} else if (queryStringWithoutComment.startsWith("set ")
&& !queryStringWithoutComment.startsWith("set @@execute_mode=")) {
// ignore SET command
ctx.writeAndFlush(OkResponse.builder().sequenceId(query.getSequenceId() + 1).build());
} else if (useDbMatcher.matches()) {
Expand All @@ -218,12 +225,9 @@ private void handleQuery(
} else {
// Generic response
int[] sequenceId = new int[] {query.getSequenceId()};

boolean[] columnsWritten = new boolean[1];

ResultSetWriter resultSetWriter =
new ResultSetWriter() {

@Override
public void writeColumns(List<QueryResultColumn> columns) {
ctx.write(new ColumnCount(++sequenceId[0], columns.size()));
Expand Down Expand Up @@ -272,9 +276,7 @@ public void writeColumns(List<QueryResultColumn> columns) {
.build());
}
ctx.write(new EofResponse(++sequenceId[0], 0));

System.out.println("[mysql-protocol] Columns done");

columnsWritten[0] = !columns.isEmpty();
}

Expand All @@ -290,7 +292,6 @@ public void writeRow(List<String> row) {
@Override
public void finish() {
ctx.writeAndFlush(new EofResponse(++sequenceId[0], 0));

System.out.println("[mysql-protocol] All done");
}
};
Expand All @@ -311,22 +312,29 @@ public void finish() {
Throwable cause = e.getCause();
int errorCode;
byte[] sqlState;
String errMsg =
Utils.getLocalDateTimeNow()
+ " "
+ Objects.requireNonNullElse(cause.getMessage(), e.getMessage());
if (cause instanceof IllegalAccessException) {
errorCode = 1045;
sqlState = "#28000".getBytes(StandardCharsets.US_ASCII);
} else if (cause instanceof IllegalArgumentException) {
errorCode = 1064;
sqlState = "#42000".getBytes(StandardCharsets.US_ASCII);
} else if (e.getMessage()
.equalsIgnoreCase(
"java.sql.SQLException: executeSQL fail: [2000] please enter database first")) {
errorCode = 1046;
sqlState = "#3D000".getBytes(StandardCharsets.US_ASCII);
String errMsg;
if (cause != null) {
errMsg =
Utils.getLocalDateTimeNow()
+ " "
+ Objects.requireNonNullElse(cause.getMessage(), e.getMessage());
if (cause instanceof IllegalAccessException) {
errorCode = 1045;
sqlState = "#28000".getBytes(StandardCharsets.US_ASCII);
} else if (cause instanceof IllegalArgumentException) {
errorCode = 1064;
sqlState = "#42000".getBytes(StandardCharsets.US_ASCII);
} else if (e.getMessage()
.equalsIgnoreCase(
"java.sql.SQLException: executeSQL fail: [2000] please enter database first")) {
errorCode = 1046;
sqlState = "#3D000".getBytes(StandardCharsets.US_ASCII);
} else {
errorCode = 1105;
sqlState = "#HY000".getBytes(StandardCharsets.US_ASCII);
}
} else {
errMsg = Utils.getLocalDateTimeNow() + " " + e.getMessage();
errorCode = 1105;
sqlState = "#HY000".getBytes(StandardCharsets.US_ASCII);
}
Expand Down Expand Up @@ -435,16 +443,18 @@ private void sendSettingsResponse(
columnDefinitions.add(
newColumnDefinition(
++sequenceId, fieldName, systemVariable, ColumnType.MYSQL_TYPE_VAR_STRING, 12));
values.add("utf8mb4");
values.add(CHARACTER_SET_UTF8MB4);
break;
case "collation_server":
case "GLOBAL.collation_server":
case "collation_connection":
columnDefinitions.add(
newColumnDefinition(
++sequenceId, fieldName, systemVariable, ColumnType.MYSQL_TYPE_VAR_STRING, 63));
values.add("utf8mb4_0900_ai_ci");
values.add(COLLATION_UTF8MB4_0900_AI_CI);
break;
case "init_connect":
case "language":
columnDefinitions.add(
newColumnDefinition(
++sequenceId, fieldName, systemVariable, ColumnType.MYSQL_TYPE_VAR_STRING, 0));
Expand All @@ -454,13 +464,7 @@ private void sendSettingsResponse(
columnDefinitions.add(
newColumnDefinition(
++sequenceId, fieldName, systemVariable, ColumnType.MYSQL_TYPE_VAR_STRING, 21));
values.add("28800");
break;
case "language":
columnDefinitions.add(
newColumnDefinition(
++sequenceId, fieldName, systemVariable, ColumnType.MYSQL_TYPE_VAR_STRING, 0));
values.add("");
values.add(SETTINGS_INTERACTIVE_TIMEOUT);
break;
case "license":
columnDefinitions.add(
Expand All @@ -472,7 +476,7 @@ private void sendSettingsResponse(
columnDefinitions.add(
newColumnDefinition(
++sequenceId, fieldName, systemVariable, ColumnType.MYSQL_TYPE_LONGLONG, 63));
values.add("2");
values.add(SETTINGS_LOWER_CASE_TABLE_NAMES);
break;
case "max_allowed_packet":
case "global.max_allowed_packet":
Expand Down Expand Up @@ -528,7 +532,7 @@ private void sendSettingsResponse(
columnDefinitions.add(
newColumnDefinition(
++sequenceId, fieldName, systemVariable, ColumnType.MYSQL_TYPE_LONGLONG, 12));
values.add("28800");
values.add(SETTINGS_WAIT_TIMEOUT);
break;
case "query_cache_type":
columnDefinitions.add(
Expand All @@ -542,12 +546,6 @@ private void sendSettingsResponse(
++sequenceId, fieldName, systemVariable, ColumnType.MYSQL_TYPE_VAR_STRING, 0));
values.add(VERSION_COMMENT);
break;
case "collation_connection":
columnDefinitions.add(
newColumnDefinition(
++sequenceId, fieldName, systemVariable, ColumnType.MYSQL_TYPE_VAR_STRING, 63));
values.add("utf8mb4_0900_ai_ci");
break;
case "query_cache_size":
columnDefinitions.add(
newColumnDefinition(
Expand Down Expand Up @@ -578,14 +576,12 @@ private void sendSettingsResponse(
newColumnDefinition(
++sequenceId, fieldName, systemVariable, ColumnType.MYSQL_TYPE_VAR_STRING, 63));
values.add("REPEATABLE-READ");
// values.add("READ-UNCOMMITTED");
break;
case "session.transaction_read_only":
columnDefinitions.add(
newColumnDefinition(
++sequenceId, fieldName, systemVariable, ColumnType.MYSQL_TYPE_TINY, 1));
values.add("0");
// values.add("READ-UNCOMMITTED");
break;
default:
System.err.println("[mysql-protocol] Unknown system variable: " + systemVariable);
Expand Down Expand Up @@ -632,7 +628,7 @@ public ServerHandler() {
}

@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
public void channelActive(ChannelHandlerContext ctx) {
// todo may java.lang.NullPointerException
this.remoteAddr =
((InetSocketAddress) ctx.channel().remoteAddress()).getAddress().getHostAddress();
Expand All @@ -650,7 +646,7 @@ public void channelActive(ChannelHandlerContext ctx) throws Exception {
}

@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
public void channelInactive(ChannelHandlerContext ctx) {
System.out.println("[mysql-protocol] Server channel inactive: " + new Date());
sqlEngine.close(getConnectionId(ctx));
}
Expand Down Expand Up @@ -682,6 +678,8 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception
} else if (command.equals(Command.COM_PING)) {
ctx.writeAndFlush(OkResponse.builder().sequenceId(sequenceId + 1).build());
} else if (command.equals(Command.COM_FIELD_LIST)) {
// ToDo:
// https://dev.mysql.com/doc/dev/mysql-server/8.0.34/page_protocol_com_field_list.html
ctx.writeAndFlush(new EofResponse(sequenceId + 1, 0));
} else if (command.equals(Command.COM_STATISTICS)) {
String statString =
Expand All @@ -696,10 +694,10 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception
}

@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
cause.printStackTrace();
ctx.close();
sqlEngine.close(getConnectionId(ctx));
ctx.close();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,22 @@

/** An interface to callback events received from the MySQL server. */
public interface SqlEngine {
/**
* Execute query use database
*
* @param connectionId Connection id
* @param database Database name
* @throws IOException Thrown with SQLTimeoutException as the inner cause if when the driver has
* determined that the timeout value that was specified by the setQueryTimeout method has been
* exceeded and has at least attempted to cancel the currently running Statement, or
* SQLException as the inner cause if a database access error occurs.
*/
void useDatabase(int connectionId, String database) throws IOException;

/**
* Authenticating the user and password.
*
* @param connectionId Connection id
* @param database Database name
* @param userName User name
* @param scramble411 Encoded password
Expand All @@ -40,6 +51,7 @@ void authenticate(
/**
* Querying the SQL.
*
* @param connectionId Connection id
* @param resultSetWriter Response writer
* @param database Database name
* @param userName User name
Expand All @@ -59,5 +71,10 @@ void query(
String sql)
throws IOException;

void close(int connectionId) throws IOException;
/**
* Close resources of connection
*
* @param connectionId Connection id
*/
void close(int connectionId);
}
Loading

0 comments on commit 1418350

Please sign in to comment.