Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support custom order in insert stmt #2075

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/workflows/cicd.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ jobs:
run: |
make test

- name: run sql_router_test
id: sql_router_test
run: |
bash steps/ut.sh sql_router_test 0

- name: run sql_sdk_test
id: sql_sdk_test
run: |
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
/*
* Copyright 2021 4Paradigm
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com._4paradigm.openmldb.common;

import java.io.Serializable;

public class Pair<K, V> implements Serializable {

/**
* Key of this <code>Pair</code>.
*/
private K key;

/**
* Gets the key for this pair.
*
* @return key for this pair
*/
public K getKey() {
return key;
}

/**
* Value of this this <code>Pair</code>.
*/
private V value;

/**
* Gets the value for this pair.
*
* @return value for this pair
*/
public V getValue() {
return value;
}

/**
* Creates a new pair
*
* @param key The key for this pair
* @param value The value to use for this pair
*/
public Pair(K key, V value) {
this.key = key;
this.value = value;
}

/**
* <p><code>String</code> representation of this
* <code>Pair</code>.</p>
*
* <p>The default name/value delimiter '=' is always used.</p>
*
* @return <code>String</code> representation of this <code>Pair</code>
*/
@Override
public String toString() {
return key + "=" + value;
}

/**
* <p>Generate a hash code for this <code>Pair</code>.</p>
*
* <p>The hash code is calculated using both the name and
* the value of the <code>Pair</code>.</p>
*
* @return hash code for this <code>Pair</code>
*/
@Override
public int hashCode() {
// name's hashCode is multiplied by an arbitrary prime number (13)
// in order to make sure there is a difference in the hashCode between
// these two parameters:
// name: a value: aa
// name: aa value: a
return key.hashCode() * 13 + (value == null ? 0 : value.hashCode());
}

/**
* <p>Test this <code>Pair</code> for equality with another
* <code>Object</code>.</p>
*
* <p>If the <code>Object</code> to be tested is not a
* <code>Pair</code> or is <code>null</code>, then this method
* returns <code>false</code>.</p>
*
* <p>Two <code>Pair</code>s are considered equal if and only if
* both the names and values are equal.</p>
*
* @param o the <code>Object</code> to test for
* equality with this <code>Pair</code>
* @return <code>true</code> if the given <code>Object</code> is
* equal to this <code>Pair</code> else <code>false</code>
*/
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o instanceof Pair) {
Pair pair = (Pair) o;
if (key != null ? !key.equals(pair.key) : pair.key != null) return false;
if (value != null ? !value.equals(pair.value) : pair.value != null) return false;
return true;
}
return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@

package com._4paradigm.openmldb.jdbc;

import static com._4paradigm.openmldb.sdk.impl.Util.sqlTypeToString;

import com._4paradigm.openmldb.DataType;
import com._4paradigm.openmldb.Schema;
import com._4paradigm.openmldb.common.Pair;
import com._4paradigm.openmldb.sdk.Common;

import java.sql.ResultSetMetaData;
Expand All @@ -28,10 +31,11 @@ public class SQLInsertMetaData implements ResultSetMetaData {

private final List<DataType> schema;
private final Schema realSchema;
private final List<Integer> idx;
private final List<Pair<Long, Integer>> idx;

public SQLInsertMetaData(List<DataType> schema,
Schema realSchema,
List<Integer> idx) {
List<Pair<Long, Integer>> idx) {
this.schema = schema;
this.realSchema = realSchema;
this.idx = idx;
Expand Down Expand Up @@ -90,7 +94,7 @@ public boolean isCurrency(int i) throws SQLException {
@Override
public int isNullable(int i) throws SQLException {
check(i);
int index = idx.get(i - 1);
Long index = idx.get(i - 1).getKey();
if (realSchema.IsColumnNotNull(index)) {
return columnNoNulls;
} else {
Expand Down Expand Up @@ -119,7 +123,7 @@ public String getColumnLabel(int i) throws SQLException {
@Override
public String getColumnName(int i) throws SQLException {
check(i);
int index = idx.get(i - 1);
Long index = idx.get(i - 1).getKey();
return realSchema.GetColumnName(index);
}

Expand Down Expand Up @@ -156,14 +160,13 @@ public String getCatalogName(int i) throws SQLException {
@Override
public int getColumnType(int i) throws SQLException {
check(i);
DataType dataType = schema.get(i - 1);
return Common.type2SqlType(dataType);
Long index = idx.get(i - 1).getKey();
return Common.type2SqlType(realSchema.GetColumnType(index));
}

@Override
@Deprecated
public String getColumnTypeName(int i) throws SQLException {
throw new SQLException("current do not support this method");
return sqlTypeToString(getColumnType(i));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import com._4paradigm.openmldb.*;

import com._4paradigm.openmldb.common.Pair;
import com._4paradigm.openmldb.jdbc.SQLInsertMetaData;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -32,6 +33,7 @@
import java.sql.Date;
import java.sql.ResultSet;
import java.util.*;
import java.util.stream.Collectors;

public class InsertPreparedStatementImpl implements PreparedStatement {
public static final Charset CHARSET = StandardCharsets.UTF_8;
Expand All @@ -48,7 +50,10 @@ public class InsertPreparedStatementImpl implements PreparedStatement {
private final List<Object> currentDatas;
private final List<DataType> currentDatasType;
private final List<Boolean> hasSet;
private final List<Integer> scehmaIdxs;
// stmt insert idx -> real table schema idx
private final List<Pair<Long, Integer>> schemaIdxes;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the first type of pair is the same as second. set as Integer

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

uint32 is long, and idx type of list is int, no the same.

// used by building row
private final List<Pair<Long, Integer>> sortedIdxes;

private boolean closed = false;
private boolean closeOnComplete = false;
Expand All @@ -63,18 +68,27 @@ public InsertPreparedStatementImpl(String db, String sql, SQLRouter router) thro
this.currentSchema = tempRow.GetSchema();
VectorUint32 idxes = tempRow.GetHoleIdx();

// In stmt order, if no columns in stmt, in schema order
// We'll sort it to schema order later, so needs the map <real_schema_idx, current_data_idx>
schemaIdxes = new ArrayList<>(idxes.size());
// CurrentData and Type order is consistent with insert stmt. We'll do appending in schema order when build
// row.
currentDatas = new ArrayList<>(idxes.size());
currentDatasType = new ArrayList<>(idxes.size());
hasSet = new ArrayList<>(idxes.size());
scehmaIdxs = new ArrayList<>(idxes.size());

for (int i = 0; i < idxes.size(); i++) {
long idx = idxes.get(i);
DataType type = currentSchema.GetColumnType(idx);
Long realIdx = idxes.get(i);
schemaIdxes.add(new Pair<>(realIdx, i));
DataType type = currentSchema.GetColumnType(realIdx);
currentDatasType.add(type);
currentDatas.add(null);
hasSet.add(false);
scehmaIdxs.add(i);
logger.debug("add col {}, {}", currentSchema.GetColumnName(realIdx), type);
}
// SQLInsertRow::AppendXXX order is the schema order(skip the no-hole columns)
sortedIdxes = schemaIdxes.stream().sorted(Comparator.comparing(Pair::getKey))
.collect(Collectors.toList());
}

private SQLInsertRow getSQLInsertRow() throws SQLException {
Expand Down Expand Up @@ -118,14 +132,14 @@ private void checkIdx(int i) throws SQLException {
if (i <= 0) {
throw new SQLException("error sqe number");
}
if (i > scehmaIdxs.size()) {
if (i > schemaIdxes.size()) {
throw new SQLException("out of data range");
}
}

private void checkType(int i, DataType type) throws SQLException {
if (currentDatasType.get(i - 1) != type) {
throw new SQLException("data type not match");
throw new SQLException("data type not match, expect " + currentDatasType.get(i - 1) + ", actual " + type);
}
}

Expand Down Expand Up @@ -206,7 +220,7 @@ public void setBigDecimal(int i, BigDecimal bigDecimal) throws SQLException {
}

private boolean checkNotAllowNull(int i) {
long idx = this.scehmaIdxs.get(i - 1);
Long idx = this.schemaIdxes.get(i - 1).getKey();
return this.currentSchema.IsColumnNotNull(idx);
}

Expand Down Expand Up @@ -300,22 +314,22 @@ public void setObject(int i, Object o, int i1) throws SQLException {

private void buildRow() throws SQLException {
SQLInsertRow currentRow = getSQLInsertRow();

boolean ok = currentRow.Init(stringsLen);
if (!ok) {
throw new SQLException("init row failed");
}

for (int i = 0; i < currentDatasType.size(); i++) {
Object data = currentDatas.get(i);
for (Pair<Long, Integer> sortedIdx : sortedIdxes) {
Integer currentDataIdx = sortedIdx.getValue();
Object data = currentDatas.get(currentDataIdx);
if (data == null) {
ok = currentRow.AppendNULL();
} else {
DataType curType = currentDatasType.get(i);
DataType curType = currentDatasType.get(currentDataIdx);
if (DataType.kTypeBool.equals(curType)) {
ok = currentRow.AppendBool((boolean) data);
} else if (DataType.kTypeDate.equals(curType)) {
java.sql.Date date = (java.sql.Date) data;
Date date = (Date) data;
ok = currentRow.AppendDate(date.getYear() + 1900, date.getMonth() + 1, date.getDate());
} else if (DataType.kTypeDouble.equals(curType)) {
ok = currentRow.AppendDouble((double) data);
Expand All @@ -333,7 +347,7 @@ private void buildRow() throws SQLException {
} else if (DataType.kTypeTimestamp.equals(curType)) {
ok = currentRow.AppendTimestamp((long) data);
} else {
throw new SQLException("unkown data type");
throw new SQLException("unknown data type");
}
}
if (!ok) {
Expand Down Expand Up @@ -423,9 +437,8 @@ public void setArray(int i, Array array) throws SQLException {
}

@Override
@Deprecated
public ResultSetMetaData getMetaData() throws SQLException {
return new SQLInsertMetaData(this.currentDatasType, this.currentSchema, this.scehmaIdxs);
return new SQLInsertMetaData(this.currentDatasType, this.currentSchema, this.schemaIdxes);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.sql.Types;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.IntStream;
Expand Down Expand Up @@ -183,7 +184,7 @@ public void testForKafkaConnector() throws SQLException {
String tableName = "kafka_test";
stmt = connection.createStatement();
try {
stmt.execute(String.format("create table if not exists %s(c1 int, c2 string)", tableName));
stmt.execute(String.format("create table if not exists %s(c1 int, c2 string, c3 timestamp)", tableName));
} catch (Exception e) {
Assert.fail();
}
Expand All @@ -198,6 +199,15 @@ public void testForKafkaConnector() throws SQLException {
pstmt.setFetchSize(100);

pstmt.addBatch();
insertSql = "INSERT INTO " +
tableName +
"(`c3`,`c2`) VALUES(?,?)";
pstmt = connection.prepareStatement(insertSql);
Assert.assertEquals(pstmt.getMetaData().getColumnCount(), 2);
// index starts from 1
Assert.assertEquals(pstmt.getMetaData().getColumnType(2), Types.VARCHAR);
Assert.assertEquals(pstmt.getMetaData().getColumnName(2), "c2");


try {
stmt = connection.prepareStatement("DELETE FROM " + tableName + " WHERE c1=1");
Expand Down
Loading