diff --git a/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClient.java b/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClient.java index 7d1dcff16826..c0cdbb263418 100644 --- a/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClient.java +++ b/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClient.java @@ -94,6 +94,7 @@ import java.util.UUID; import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiFunction; +import java.util.stream.Collectors; import static com.clickhouse.data.ClickHouseValues.convertToQuotedString; import static com.google.common.base.Preconditions.checkArgument; @@ -299,7 +300,7 @@ protected String createTableSql(RemoteTableName remoteTableName, List co formatProperty(ClickHouseTableProperties.getOrderBy(tableProperties)).ifPresent(value -> tableOptions.add("ORDER BY " + value)); formatProperty(ClickHouseTableProperties.getPrimaryKey(tableProperties)).ifPresent(value -> tableOptions.add("PRIMARY KEY " + value)); formatProperty(ClickHouseTableProperties.getPartitionBy(tableProperties)).ifPresent(value -> tableOptions.add("PARTITION BY " + value)); - ClickHouseTableProperties.getSampleBy(tableProperties).ifPresent(value -> tableOptions.add("SAMPLE BY " + value)); + ClickHouseTableProperties.getSampleBy(tableProperties).ifPresent(value -> tableOptions.add("SAMPLE BY " + quoted(value))); tableMetadata.getComment().ifPresent(comment -> tableOptions.add(format("COMMENT %s", clickhouseVarcharLiteral(comment)))); return format("CREATE TABLE %s (%s) %s", quoted(remoteTableName), join(", ", columns), join(" ", tableOptions.build())); @@ -363,7 +364,7 @@ public void setTableProperties(ConnectorSession session, JdbcTableHandle handle, .collect(toImmutableMap(Entry::getKey, entry -> entry.getValue().orElseThrow())); ImmutableList.Builder tableOptions = ImmutableList.builder(); - ClickHouseTableProperties.getSampleBy(properties).ifPresent(value -> tableOptions.add("SAMPLE BY " + value)); + ClickHouseTableProperties.getSampleBy(properties).ifPresent(value -> tableOptions.add("SAMPLE BY " + quoted(value))); try (Connection connection = connectionFactory.openConnection(session)) { String sql = format( @@ -720,10 +721,10 @@ private Optional formatProperty(List prop) } if (prop.size() == 1) { // only one column - return Optional.of(prop.get(0)); + return Optional.of(quoted(prop.get(0))); } // include more than one column - return Optional.of("(" + String.join(",", prop) + ")"); + return Optional.of(prop.stream().map(this::quoted).collect(Collectors.joining(",", "(", ")"))); } private static LongWriteFunction uInt8WriteFunction(ClickHouseVersion version) diff --git a/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/BaseClickHouseConnectorTest.java b/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/BaseClickHouseConnectorTest.java index 90383f5bee39..d4203dce5a61 100644 --- a/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/BaseClickHouseConnectorTest.java +++ b/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/BaseClickHouseConnectorTest.java @@ -83,6 +83,15 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) } } + @Test + public void testSampleBySqlInjection() + { + assertQueryFails("CREATE TABLE test (p1 int NOT NULL, p2 boolean NOT NULL, x VARCHAR) WITH (engine = 'MergeTree', order_by = ARRAY['p1', 'p2'], primary_key = ARRAY['p1', 'p2'], sample_by = 'p2; drop table tpch.nation')", "(?s).*Missing columns: 'p2; drop table tpch.nation.*"); + assertUpdate("CREATE TABLE test (p1 int NOT NULL, p2 boolean NOT NULL, x VARCHAR) WITH (engine = 'MergeTree', order_by = ARRAY['p1', 'p2'], primary_key = ARRAY['p1', 'p2'], sample_by = 'p2')"); + assertQueryFails("ALTER TABLE test SET PROPERTIES sample_by = 'p2; drop table tpch.nation'", "(?s).*Missing columns: 'p2; drop table tpch.nation.*"); + assertUpdate("ALTER TABLE test SET PROPERTIES sample_by = 'p2'"); + } + @Override @Test(dataProvider = "testColumnNameDataProvider") public void testColumnName(String columnName) @@ -352,7 +361,7 @@ public void testDifferentEngine() // MergeTree with optional assertUpdate("CREATE TABLE " + tableName + " (id int NOT NULL, x VARCHAR, logdate DATE NOT NULL) WITH " + - "(engine = 'MergeTree', order_by = ARRAY['id'], partition_by = ARRAY['toYYYYMM(logdate)'])"); + "(engine = 'MergeTree', order_by = ARRAY['id'], partition_by = ARRAY['logdate'])"); assertTrue(getQueryRunner().tableExists(getSession(), tableName)); assertUpdate("DROP TABLE " + tableName);