Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Common HTTP insert query normalization #1341

Merged
merged 1 commit into from
Jul 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions batch.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Licensed to ClickHouse, Inc. under one or more contributor
// license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright
// ownership. ClickHouse, Inc. licenses this file to you under
// the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package clickhouse

import (
"fmt"
"regexp"
"strings"

"github.com/pkg/errors"
)

var normalizeInsertQueryMatch = regexp.MustCompile(`(?i)(INSERT\s+INTO\s+([^( ]+)(?:\s*\([^()]*(?:\([^()]*\)[^()]*)*\))?)(?:\s*VALUES)?`)
var extractInsertColumnsMatch = regexp.MustCompile(`INSERT INTO .+\s\((?P<Columns>.+)\)$`)

func extractNormalizedInsertQueryAndColumns(query string) (normalizedQuery string, tableName string, columns []string, err error) {
matches := normalizeInsertQueryMatch.FindStringSubmatch(query)
if len(matches) == 0 {
err = errors.Errorf("invalid INSERT query: %s", query)
return
}

normalizedQuery = fmt.Sprintf("%s FORMAT Native", matches[1])
tableName = matches[2]

columns = make([]string, 0)
matches = extractInsertColumnsMatch.FindStringSubmatch(matches[1])
if len(matches) == 2 {
columns = strings.Split(matches[1], ",")
for i := range columns {
// refers to https://clickhouse.com/docs/en/sql-reference/syntax#identifiers
// we can use identifiers with double quotes or backticks, for example: "id", `id`, but not both, like `"id"`.
columns[i] = strings.Trim(strings.Trim(strings.TrimSpace(columns[i]), "\""), "`")
}
}

return
}
89 changes: 89 additions & 0 deletions batch_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// Licensed to ClickHouse, Inc. under one or more contributor
// license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright
// ownership. ClickHouse, Inc. licenses this file to you under
// the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package clickhouse

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestExtractNormalizedInsertQueryAndColumns(t *testing.T) {
var testCases = []struct {
query string
expectedNormalizedQuery string
expectedTableName string
expectedColumns []string
expectedError bool
}{
{
query: "INSERT INTO table_name (col1, col2) VALUES (1, 2)",
expectedNormalizedQuery: "INSERT INTO table_name (col1, col2) FORMAT Native",
expectedTableName: "table_name",
expectedColumns: []string{"col1", "col2"},
expectedError: false,
},
{
query: "INSERT INTO `db`.`table_name` (col1, col2) VALUES (1, 2)",
expectedNormalizedQuery: "INSERT INTO `db`.`table_name` (col1, col2) FORMAT Native",
expectedTableName: "`db`.`table_name`",
expectedColumns: []string{"col1", "col2"},
expectedError: false,
},
{
query: "INSERT INTO table_name (col1, col2) VALUES (1, 2) FORMAT Native",
expectedNormalizedQuery: "INSERT INTO table_name (col1, col2) FORMAT Native",
expectedTableName: "table_name",
expectedColumns: []string{"col1", "col2"},
expectedError: false,
},
{
query: "INSERT INTO table_name",
expectedNormalizedQuery: "INSERT INTO table_name FORMAT Native",
expectedTableName: "table_name",
expectedColumns: []string{},
expectedError: false,
},
{
query: "INSERT INTO table_name FORMAT Native",
expectedNormalizedQuery: "INSERT INTO table_name FORMAT Native",
expectedTableName: "table_name",
expectedColumns: []string{},
expectedError: false,
},
{
query: "SELECT * FROM table_name",
expectedError: true,
},
}

for _, tc := range testCases {
t.Run(tc.query, func(t *testing.T) {
normalizedQuery, tableName, columns, err := extractNormalizedInsertQueryAndColumns(tc.query)
if tc.expectedError {
assert.Error(t, err)
return
}

assert.NoError(t, err)
assert.Equal(t, tc.expectedNormalizedQuery, normalizedQuery)
assert.Equal(t, tc.expectedTableName, tableName)
assert.Equal(t, tc.expectedColumns, columns)
})
}
}
28 changes: 4 additions & 24 deletions conn_batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
"os"
"regexp"
"slices"
"strings"
"time"

"github.com/pkg/errors"
Expand All @@ -37,29 +36,10 @@ var insertMatch = regexp.MustCompile(`(?i)(INSERT\s+INTO\s+[^( ]+(?:\s*\([^()]*(
var columnMatch = regexp.MustCompile(`INSERT INTO .+\s\((?P<Columns>.+)\)$`)

func (c *connect) prepareBatch(ctx context.Context, query string, opts driver.PrepareBatchOptions, release func(*connect, error), acquire func(context.Context) (*connect, error)) (driver.Batch, error) {
//defer func() {
// if err := recover(); err != nil {
// fmt.Printf("panic occurred on %d:\n", c.num)
// }
//}()
subMatches := insertMatch.FindStringSubmatch(query)
if len(subMatches) > 1 {
query = subMatches[1]
} else {
return nil, errors.New("invalid query")
}

colMatch := columnMatch.FindStringSubmatch(query)
var columns []string
if len(colMatch) == 2 {
columns = strings.Split(colMatch[1], ",")
for i := range columns {
// refers to https://clickhouse.com/docs/en/sql-reference/syntax#identifiers
// we can use identifiers with double quotes or backticks, for example: "id", `id`, but not both, like `"id"`.
columns[i] = strings.Trim(strings.Trim(strings.TrimSpace(columns[i]), "\""), "`")
}
query, _, queryColumns, verr := extractNormalizedInsertQueryAndColumns(query)
if verr != nil {
return nil, verr
}
query += " VALUES"

options := queryOptions(ctx)
if deadline, ok := ctx.Deadline(); ok {
Expand All @@ -79,7 +59,7 @@ func (c *connect) prepareBatch(ctx context.Context, query string, opts driver.Pr
return nil, err
}
// resort batch to specified columns
if err = block.SortColumns(columns); err != nil {
if err = block.SortColumns(queryColumns); err != nil {
return nil, err
}

Expand Down
32 changes: 8 additions & 24 deletions conn_http_batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,47 +19,31 @@ package clickhouse

import (
"context"
"errors"
"fmt"
"io"
"regexp"
"slices"
"strings"

"github.com/ClickHouse/clickhouse-go/v2/lib/column"
"github.com/ClickHouse/clickhouse-go/v2/lib/driver"
"github.com/ClickHouse/clickhouse-go/v2/lib/proto"
)

// \x60 represents a backtick
var httpInsertRe = regexp.MustCompile(`(?i)^INSERT INTO\s+\x60?([\w.^\(]+)\x60?\s*(\([^\)]*\))?`)

// release is ignored, because http used by std with empty release function.
// Also opts ignored because all options unused in http batch.
func (h *httpConnect) prepareBatch(ctx context.Context, query string, opts driver.PrepareBatchOptions, release func(*connect, error), acquire func(context.Context) (*connect, error)) (driver.Batch, error) {
matches := httpInsertRe.FindStringSubmatch(query)
if len(matches) < 3 {
return nil, errors.New("cannot get table name from query")
}
tableName := matches[1]
var rColumns []string
if matches[2] != "" {
colMatch := strings.TrimSuffix(strings.TrimPrefix(matches[2], "("), ")")
rColumns = strings.Split(colMatch, ",")
for i := range rColumns {
rColumns[i] = strings.Trim(strings.TrimSpace(rColumns[i]), "`")
}
query, tableName, queryColumns, err := extractNormalizedInsertQueryAndColumns(query)
if err != nil {
return nil, err
}
query = "INSERT INTO " + tableName + " FORMAT Native"
queryTableSchema := "DESCRIBE TABLE " + tableName
r, err := h.query(ctx, release, queryTableSchema)

describeTableQuery := fmt.Sprintf("DESCRIBE TABLE %s", tableName)
r, err := h.query(ctx, release, describeTableQuery)
if err != nil {
return nil, err
}

block := &proto.Block{}

// get Table columns and types
columns := make(map[string]string)
var colNames []string
for r.Next() {
Expand All @@ -81,7 +65,7 @@ func (h *httpConnect) prepareBatch(ctx context.Context, query string, opts drive
columns[colName] = colType
}

switch len(rColumns) {
switch len(queryColumns) {
case 0:
for _, colName := range colNames {
if err = block.AddColumn(colName, column.Type(columns[colName])); err != nil {
Expand All @@ -90,7 +74,7 @@ func (h *httpConnect) prepareBatch(ctx context.Context, query string, opts drive
}
default:
// user has requested specific columns so only include these
for _, colName := range rColumns {
for _, colName := range queryColumns {
if colType, ok := columns[colName]; ok {
if err = block.AddColumn(colName, column.Type(colType)); err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion tests/batch_block_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import (
func TestBatchAppendRows(t *testing.T) {
te, err := GetTestEnvironment(testSet)
require.NoError(t, err)
opts := ClientOptionsFromEnv(te, clickhouse.Settings{})
opts := ClientOptionsFromEnv(te, clickhouse.Settings{}, false)

conn, err := GetConnectionWithOptions(&opts)
require.NoError(t, err)
Expand Down
2 changes: 1 addition & 1 deletion tests/batch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import (
func TestBatchContextCancellation(t *testing.T) {
te, err := GetTestEnvironment(testSet)
require.NoError(t, err)
opts := ClientOptionsFromEnv(te, clickhouse.Settings{})
opts := ClientOptionsFromEnv(te, clickhouse.Settings{}, false)
opts.MaxOpenConns = 1
conn, err := GetConnectionWithOptions(&opts)
require.NoError(t, err)
Expand Down
7 changes: 4 additions & 3 deletions tests/client_info_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@ package tests
import (
"context"
"fmt"
"runtime"
"testing"

"github.com/ClickHouse/clickhouse-go/v2"
"github.com/ClickHouse/clickhouse-go/v2/lib/driver"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"runtime"
"testing"
)

func TestClientInfo(t *testing.T) {
Expand Down Expand Up @@ -89,7 +90,7 @@ func TestClientInfo(t *testing.T) {

for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
opts := ClientOptionsFromEnv(env, clickhouse.Settings{})
opts := ClientOptionsFromEnv(env, clickhouse.Settings{}, false)
opts.ClientInfo = testCase.clientInfo

conn, err := clickhouse.Open(&opts)
Expand Down
4 changes: 2 additions & 2 deletions tests/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ func TestConnCustomDialStrategy(t *testing.T) {
env, err := GetTestEnvironment(testSet)
require.NoError(t, err)

opts := ClientOptionsFromEnv(env, clickhouse.Settings{})
opts := ClientOptionsFromEnv(env, clickhouse.Settings{}, false)
validAddr := opts.Addr[0]
opts.Addr = []string{"invalid.host:9001"}

Expand Down Expand Up @@ -342,7 +342,7 @@ func TestConnectionExpiresIdleConnection(t *testing.T) {
expectedConnections := getActiveConnections(t, baseConn)

// when the client is configured to expire idle connections after 1/10 of a second
opts := ClientOptionsFromEnv(testEnv, clickhouse.Settings{})
opts := ClientOptionsFromEnv(testEnv, clickhouse.Settings{}, false)
opts.MaxIdleConns = 20
opts.MaxOpenConns = 20
opts.ConnMaxLifetime = time.Second / 10
Expand Down
35 changes: 35 additions & 0 deletions tests/issues/1329_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package issues

import (
"database/sql"
"fmt"
"testing"

"github.com/ClickHouse/clickhouse-go/v2"
clickhouse_tests "github.com/ClickHouse/clickhouse-go/v2/tests"
"github.com/stretchr/testify/require"
)

func Test1329(t *testing.T) {
testEnv, err := clickhouse_tests.GetTestEnvironment("issues")
require.NoError(t, err)
opts := clickhouse_tests.ClientOptionsFromEnv(testEnv, clickhouse.Settings{}, true)
conn, err := sql.Open("clickhouse", clickhouse_tests.OptionsToDSN(&opts))
require.NoError(t, err)

_, err = conn.Exec(`CREATE TABLE test_1329 (Col String) Engine MergeTree() ORDER BY tuple()`)
require.NoError(t, err)
t.Cleanup(func() {
_, _ = conn.Exec("DROP TABLE test_1329")
})

scope, err := conn.Begin()

batch, err := scope.Prepare(fmt.Sprintf("INSERT INTO `%s`.`test_1329`", testEnv.Database))
require.NoError(t, err)
_, err = batch.Exec(
"str",
)
require.NoError(t, err)
require.NoError(t, scope.Commit())
}
7 changes: 4 additions & 3 deletions tests/issues/957_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@ package issues

import (
"context"
"testing"
"time"

"github.com/ClickHouse/clickhouse-go/v2"
clickhouse_tests "github.com/ClickHouse/clickhouse-go/v2/tests"
"github.com/stretchr/testify/require"
"testing"
"time"
)

func Test957(t *testing.T) {
Expand All @@ -33,7 +34,7 @@ func Test957(t *testing.T) {
require.NoError(t, err)

// when the client is configured to use the test environment
opts := clickhouse_tests.ClientOptionsFromEnv(testEnv, clickhouse.Settings{})
opts := clickhouse_tests.ClientOptionsFromEnv(testEnv, clickhouse.Settings{}, false)
// and the client is configured to have only 1 connection
opts.MaxIdleConns = 2
opts.MaxOpenConns = 1
Expand Down
Loading
Loading