From 3e777af54acb5ed095772c5d9fb59e9700039422 Mon Sep 17 00:00:00 2001 From: Kuba Kaflik Date: Fri, 30 Aug 2024 22:17:59 +0200 Subject: [PATCH] Validate connection in bad state before query execution in the stdlib database/sql driver (#1396) The current behavior of a library is to invalidate connection if it encounters any of ClickHouse errors. Connection in bad state shouldn't be reused. Stdlib driver attempts to use connection without checking whether the connection is in a good condition. --- clickhouse_std.go | 39 ++++++++++++++- tests/issues/1395_test.go | 100 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 137 insertions(+), 2 deletions(-) create mode 100644 tests/issues/1395_test.go diff --git a/clickhouse_std.go b/clickhouse_std.go index 7b18480a2b..466d345fb2 100644 --- a/clickhouse_std.go +++ b/clickhouse_std.go @@ -239,12 +239,32 @@ func (std *stdDriver) ResetSession(ctx context.Context) error { var _ driver.SessionResetter = (*stdDriver)(nil) -func (std *stdDriver) Ping(ctx context.Context) error { return std.conn.ping(ctx) } +func (std *stdDriver) Ping(ctx context.Context) error { + if std.conn.isBad() { + std.debugf("Ping: connection is bad") + return driver.ErrBadConn + } + + return std.conn.ping(ctx) +} var _ driver.Pinger = (*stdDriver)(nil) -func (std *stdDriver) Begin() (driver.Tx, error) { return std, nil } +func (std *stdDriver) Begin() (driver.Tx, error) { + if std.conn.isBad() { + std.debugf("Begin: connection is bad") + return nil, driver.ErrBadConn + } + + return std, nil +} + func (std *stdDriver) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + if std.conn.isBad() { + std.debugf("BeginTx: connection is bad") + return nil, driver.ErrBadConn + } + return std, nil } @@ -280,6 +300,11 @@ func (std *stdDriver) CheckNamedValue(nv *driver.NamedValue) error { return nil var _ driver.NamedValueChecker = (*stdDriver)(nil) func (std *stdDriver) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + if std.conn.isBad() { + std.debugf("ExecContext: connection is bad") + return nil, driver.ErrBadConn + } + var err error if options := queryOptions(ctx); options.async.ok { err = std.conn.asyncInsert(ctx, query, options.async.wait, rebind(args)...) @@ -299,6 +324,11 @@ func (std *stdDriver) ExecContext(ctx context.Context, query string, args []driv } func (std *stdDriver) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + if std.conn.isBad() { + std.debugf("QueryContext: connection is bad") + return nil, driver.ErrBadConn + } + r, err := std.conn.query(ctx, func(*connect, error) {}, query, rebind(args)...) if isConnBrokenError(err) { std.debugf("QueryContext got a fatal error, resetting connection: %v\n", err) @@ -319,6 +349,11 @@ func (std *stdDriver) Prepare(query string) (driver.Stmt, error) { } func (std *stdDriver) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + if std.conn.isBad() { + std.debugf("PrepareContext: connection is bad") + return nil, driver.ErrBadConn + } + batch, err := std.conn.prepareBatch(ctx, query, ldriver.PrepareBatchOptions{}, func(*connect, error) {}, func(context.Context) (*connect, error) { return nil, nil }) if err != nil { if isConnBrokenError(err) { diff --git a/tests/issues/1395_test.go b/tests/issues/1395_test.go new file mode 100644 index 0000000000..7292733ced --- /dev/null +++ b/tests/issues/1395_test.go @@ -0,0 +1,100 @@ +// Licensed to ClickHouse, Inc. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. ClickHouse, Inc. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package issues + +import ( + "context" + "database/sql" + "database/sql/driver" + "testing" + + "github.com/ClickHouse/clickhouse-go/v2" + clickhouse_tests "github.com/ClickHouse/clickhouse-go/v2/tests" + "github.com/pkg/errors" + "github.com/stretchr/testify/require" +) + +func Test1395(t *testing.T) { + testEnv, err := clickhouse_tests.GetTestEnvironment("issues") + require.NoError(t, err) + opts := clickhouse_tests.ClientOptionsFromEnv(testEnv, clickhouse.Settings{}, false) + conn, err := sql.Open("clickhouse", clickhouse_tests.OptionsToDSN(&opts)) + require.NoError(t, err) + + ctx := context.Background() + + singleConn, err := conn.Conn(ctx) + if err != nil { + t.Fatalf("Get single conn from pool: %v", err) + } + + tx1 := func(c *sql.Conn) error { + tx, err := c.BeginTx(ctx, nil) + if err != nil { + return errors.Wrap(err, "begin tx") + } + defer tx.Rollback() + + _, err = tx.ExecContext(ctx, ` +CREATE TABLE IF NOT EXISTS test_table +ON CLUSTER my +(id UInt32, name String) +ENGINE = MergeTree() +ORDER BY id`) + if err != nil { + return errors.Wrap(err, "create table") + } + + err = tx.Commit() + if err != nil { + return errors.Wrap(err, "commit tx") + } + + return nil + } + + err = tx1(singleConn) + require.Error(t, err, "expected error due to cluster is not configured") + + tx2 := func(c *sql.Conn) error { + tx, err := c.BeginTx(ctx, nil) + if err != nil { + return errors.Wrap(err, "begin tx") + } + defer tx.Rollback() + + _, err = tx.ExecContext(ctx, "INSERT INTO test_table (id, name) VALUES (?, ?)", 1, "test_name") + if err != nil { + return errors.Wrap(err, "failed to insert record") + } + err = tx.Commit() + if err != nil { + return errors.Wrap(err, "commit tx") + } + + return nil + } + require.NotPanics( + t, + func() { + err := tx2(singleConn) + require.ErrorIs(t, err, driver.ErrBadConn) + }, + "must not panics", + ) +}