Skip to content

Commit

Permalink
Allow INSERT null for SQL Server varbinary type
Browse files Browse the repository at this point in the history
  • Loading branch information
ebyhr committed Sep 13, 2020
1 parent 75f426e commit 05b1226
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package io.prestosql.plugin.sqlserver;

import com.google.common.base.Joiner;
import io.airlift.slice.Slice;
import io.prestosql.plugin.jdbc.BaseJdbcClient;
import io.prestosql.plugin.jdbc.BaseJdbcConfig;
import io.prestosql.plugin.jdbc.ColumnMapping;
Expand All @@ -24,28 +25,36 @@
import io.prestosql.plugin.jdbc.JdbcTypeHandle;
import io.prestosql.plugin.jdbc.PredicatePushdownController;
import io.prestosql.plugin.jdbc.PredicatePushdownController.DomainPushdownResult;
import io.prestosql.plugin.jdbc.SliceWriteFunction;
import io.prestosql.plugin.jdbc.WriteMapping;
import io.prestosql.spi.PrestoException;
import io.prestosql.spi.connector.ConnectorSession;
import io.prestosql.spi.connector.SchemaTableName;
import io.prestosql.spi.predicate.Domain;
import io.prestosql.spi.type.CharType;
import io.prestosql.spi.type.Type;
import io.prestosql.spi.type.VarbinaryType;
import io.prestosql.spi.type.VarcharType;

import javax.inject.Inject;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.List;
import java.util.Optional;
import java.util.function.BiFunction;

import static com.google.common.base.Preconditions.checkArgument;
import static io.airlift.slice.Slices.wrappedBuffer;
import static io.prestosql.plugin.jdbc.ColumnMapping.DISABLE_PUSHDOWN;
import static io.prestosql.plugin.jdbc.JdbcErrorCode.JDBC_ERROR;
import static io.prestosql.plugin.jdbc.StandardColumnMappings.booleanWriteFunction;
import static io.prestosql.plugin.jdbc.StandardColumnMappings.charWriteFunction;
import static io.prestosql.plugin.jdbc.StandardColumnMappings.varcharWriteFunction;
import static io.prestosql.spi.StandardErrorCode.NOT_SUPPORTED;
import static io.prestosql.spi.type.BooleanType.BOOLEAN;
import static io.prestosql.spi.type.VarbinaryType.VARBINARY;
import static io.prestosql.spi.type.Varchars.isVarcharType;
import static java.lang.String.format;
import static java.util.stream.Collectors.joining;
Expand Down Expand Up @@ -117,6 +126,14 @@ public Optional<ColumnMapping> toPrestoType(ConnectorSession session, Connection
if (mapping.isPresent()) {
return mapping;
}

String jdbcTypeName = typeHandle.getJdbcTypeName()
.orElseThrow(() -> new PrestoException(JDBC_ERROR, "Type name is missing: " + typeHandle));

if (jdbcTypeName.equals("varbinary")) {
return Optional.of(varbinaryColumnMapping());
}

// TODO (https://github.com/prestosql/presto/issues/4593) implement proper type mapping
return super.toPrestoType(session, connection, typeHandle)
.map(columnMapping -> new ColumnMapping(
Expand Down Expand Up @@ -157,6 +174,10 @@ public WriteMapping toWriteMapping(ConnectorSession session, Type type)
return WriteMapping.sliceMapping(dataType, charWriteFunction());
}

if (type instanceof VarbinaryType) {
return WriteMapping.sliceMapping("varbinary(max)", varbinaryWriteFunction());
}

// TODO implement proper type mapping
return super.toWriteMapping(session, type);
}
Expand Down Expand Up @@ -186,4 +207,32 @@ private static String singleQuote(String literal)
{
return "\'" + literal + "\'";
}

public static ColumnMapping varbinaryColumnMapping()
{
return ColumnMapping.sliceMapping(
VARBINARY,
(resultSet, columnIndex) -> wrappedBuffer(resultSet.getBytes(columnIndex)),
varbinaryWriteFunction(),
DISABLE_PUSHDOWN);
}

private static SliceWriteFunction varbinaryWriteFunction()
{
return new SliceWriteFunction() {
@Override
public void set(PreparedStatement statement, int index, Slice value)
throws SQLException
{
statement.setBytes(index, value.getBytes());
}

@Override
public void setNull(PreparedStatement statement, int index)
throws SQLException
{
statement.setBytes(index, null);
}
};
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.prestosql.plugin.sqlserver;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.prestosql.Session;
import io.prestosql.testing.AbstractTestQueryFramework;
import io.prestosql.testing.QueryRunner;
import io.prestosql.testing.datatype.CreateAsSelectDataSetup;
import io.prestosql.testing.datatype.DataSetup;
import io.prestosql.testing.datatype.DataTypeTest;
import io.prestosql.testing.sql.PrestoSqlExecutor;
import org.testng.annotations.AfterClass;
import org.testng.annotations.Test;

import static io.prestosql.plugin.sqlserver.SqlServerQueryRunner.createSqlServerQueryRunner;
import static io.prestosql.testing.datatype.DataType.varbinaryDataType;
import static java.nio.charset.StandardCharsets.UTF_16LE;
import static java.nio.charset.StandardCharsets.UTF_8;

public class TestSqlServerTypeMapping
extends AbstractTestQueryFramework
{
private TestingSqlServer sqlServer;

@Override
protected QueryRunner createQueryRunner()
throws Exception
{
sqlServer = new TestingSqlServer();
sqlServer.start();
return createSqlServerQueryRunner(
sqlServer,
ImmutableMap.of(),
ImmutableList.of());
}

@AfterClass(alwaysRun = true)
public final void destroy()
{
sqlServer.close();
}

@Test
public void testVarbinary()
{
DataTypeTest.create()
.addRoundTrip(varbinaryDataType(), null)
.addRoundTrip(varbinaryDataType(), "hello".getBytes(UTF_8))
.addRoundTrip(varbinaryDataType(), "Piękna łąka w 東京都".getBytes(UTF_8))
.addRoundTrip(varbinaryDataType(), "Bag full of 💰".getBytes(UTF_16LE))
.addRoundTrip(varbinaryDataType(), new byte[] {})
.addRoundTrip(varbinaryDataType(), new byte[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 13, -7, 54, 122, -89, 0, 0, 0})
.execute(getQueryRunner(), prestoCreateAsSelect("test_varbinary"));
}

private DataSetup prestoCreateAsSelect(String tableNamePrefix)
{
return prestoCreateAsSelect(getSession(), tableNamePrefix);
}

private DataSetup prestoCreateAsSelect(Session session, String tableNamePrefix)
{
return new CreateAsSelectDataSetup(new PrestoSqlExecutor(getQueryRunner(), session), tableNamePrefix);
}
}

0 comments on commit 05b1226

Please sign in to comment.