Skip to content

Commit

Permalink
Merge pull request #91 from hellofresh/feature/allow-cancelation-long…
Browse files Browse the repository at this point in the history
…-queries

Allow cancelation of long queries
  • Loading branch information
rafaeljesus authored Mar 6, 2018
2 parents 537cf56 + 0e76947 commit 2b27fed
Show file tree
Hide file tree
Showing 8 changed files with 62 additions and 36 deletions.
4 changes: 2 additions & 2 deletions cmd/steal.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ func NewStealCmd() *cobra.Command {
cmd.PersistentFlags().StringVarP(&opts.from, "from", "f", "root:root@tcp(localhost:3306)/klepto", "Database dsn to steal from")
cmd.PersistentFlags().StringVarP(&opts.to, "to", "t", "os://stdout/", "Database to output to (default writes to stdOut)")
cmd.PersistentFlags().IntVar(&opts.concurrency, "concurrency", 4, "Sets the amount of dumps to be performed concurrently")
cmd.PersistentFlags().StringVar(&opts.readOpts.timeout, "read-timeout", "30s", "Sets the timeout for all read operations")
cmd.PersistentFlags().StringVar(&opts.writeOpts.timeout, "write-timeout", "30s", "Sets the timeout for all write operations")
cmd.PersistentFlags().StringVar(&opts.readOpts.timeout, "read-timeout", "5m", "Sets the timeout for read operations")
cmd.PersistentFlags().StringVar(&opts.writeOpts.timeout, "write-timeout", "30s", "Sets the timeout for write operations")
cmd.PersistentFlags().StringVar(&opts.readOpts.maxConnLifetime, "read-conn-lifetime", "0", "Sets the maximum amount of time a connection may be reused on the read database")
cmd.PersistentFlags().IntVar(&opts.readOpts.maxConns, "read-max-conns", 10, "Sets the maximum number of open connections to the read database")
cmd.PersistentFlags().IntVar(&opts.readOpts.maxIdleConns, "read-max-idle-conns", 0, "Sets the maximum number of connections in the idle connection pool for the read database")
Expand Down
11 changes: 6 additions & 5 deletions features/mysql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"os"
"path"
"testing"
"time"

"github.com/go-sql-driver/mysql"
"github.com/hellofresh/klepto/pkg/config"
Expand All @@ -19,15 +20,15 @@ import (

type MysqlTestSuite struct {
suite.Suite

rootDSN string
rootConnection *sql.DB

databases []string
databases []string
timeout time.Duration
}

func TestMysqlTestSuite(t *testing.T) {
suite.Run(t, new(MysqlTestSuite))
s := &MysqlTestSuite{timeout: time.Second * 3}
suite.Run(t, s)
}

func (s *MysqlTestSuite) TestExample() {
Expand All @@ -36,7 +37,7 @@ func (s *MysqlTestSuite) TestExample() {

s.loadFixture(readDSN, "mysql_simple.sql")

rdr, err := reader.Connect(reader.ConnOpts{DSN: readDSN})
rdr, err := reader.Connect(reader.ConnOpts{DSN: readDSN, Timeout: s.timeout})
s.Require().NoError(err, "Unable to create reader")
defer rdr.Close()

Expand Down
11 changes: 6 additions & 5 deletions features/postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"os"
"path"
"testing"
"time"

"strconv"

Expand All @@ -21,11 +22,10 @@ import (

type PostgresTestSuite struct {
suite.Suite

rootDSN string
rootConnection *sql.DB

databases []string
databases []string
timeout time.Duration
}

type tableInfo struct {
Expand All @@ -35,7 +35,8 @@ type tableInfo struct {
}

func TestPostgresTestSuite(t *testing.T) {
suite.Run(t, new(PostgresTestSuite))
s := &PostgresTestSuite{timeout: time.Second * 3}
suite.Run(t, s)
}

func (s *PostgresTestSuite) TestExample() {
Expand All @@ -44,7 +45,7 @@ func (s *PostgresTestSuite) TestExample() {

s.loadFixture(readDSN, "pg_simple.sql")

rdr, err := reader.Connect(reader.ConnOpts{DSN: readDSN})
rdr, err := reader.Connect(reader.ConnOpts{DSN: readDSN, Timeout: s.timeout})
s.Require().NoError(err, "Unable to create reader")
defer rdr.Close()

Expand Down
50 changes: 37 additions & 13 deletions pkg/reader/generic/sql.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package generic

import (
"context"
"database/sql"
"fmt"
"sync"
"time"

sq "github.com/Masterminds/squirrel"
"github.com/hellofresh/klepto/pkg/database"
Expand All @@ -20,6 +22,8 @@ type (
tables []string
// columns is a cache variable for tables and there columns in the db
columns sync.Map
// timeout is the sql read operation timeout
timeout time.Duration
}

SqlEngine interface {
Expand All @@ -44,8 +48,8 @@ type (
)

// NewSqlReader creates a new sql reader
func NewSqlReader(se SqlEngine) *SqlReader {
return &SqlReader{SqlEngine: se}
func NewSqlReader(se SqlEngine, t time.Duration) *SqlReader {
return &SqlReader{SqlEngine: se, timeout: t}
}

// GetTables gets a list of all tables in the database
Expand Down Expand Up @@ -93,20 +97,40 @@ func (s *SqlReader) ReadTable(tableName string, rowChan chan<- database.Row, opt
opts.Columns = s.formatColumns(tableName, columns)
}

query, err := s.buildQuery(tableName, opts)
var (
query sq.SelectBuilder
err error
)
query, err = s.buildQuery(tableName, opts)
if err != nil {
return errors.Wrapf(err, "failed to build query for %s", tableName)
}

rows, err := query.RunWith(s.GetConnection()).Query()
if err != nil {
querySQL, queryParams, _ := query.ToSql()
logger.WithFields(log.Fields{
"query": querySQL,
"params": queryParams,
}).Warn("failed to query rows")

return errors.Wrap(err, "failed to query rows")
var rows *sql.Rows
ctx, cancel := context.WithTimeout(context.Background(), s.timeout)
defer cancel()

errchan := make(chan error)
go func() {
defer close(errchan)
rows, err = query.RunWith(s.GetConnection()).QueryContext(ctx)
errchan <- err
}()

select {
case <-ctx.Done():
return errors.Wrapf(ctx.Err(), fmt.Sprintf("timeout during read %s table", tableName))
case err := <-errchan:
if err != nil {
querySQL, queryParams, _ := query.ToSql()
logger.WithError(err).
WithFields(log.Fields{
"query": querySQL,
"params": queryParams,
}).Warn("failed to query rows")
return errors.Wrap(err, "failed to query rows")
}
break
}

return s.publishRows(rows, rowChan, tableName)
Expand Down Expand Up @@ -179,7 +203,7 @@ func (s *SqlReader) publishRows(rows *sql.Rows, rowChan chan<- database.Row, tab
}

if err := rows.Scan(fieldPointers...); err != nil {
log.WithError(err).Warn("failed to fetch row")
log.WithError(err).WithField("table", tableName).Warn("failed to fetch row")
continue
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/reader/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func (m *driver) NewConnection(opts reader.ConnOpts) (reader.Reader, error) {
conn.SetMaxIdleConns(opts.MaxIdleConns)
conn.SetConnMaxLifetime(opts.MaxConnLifetime)

return NewStorage(conn), nil
return NewStorage(conn, opts.Timeout), nil
}

func init() {
Expand Down
4 changes: 2 additions & 2 deletions pkg/reader/mysql/reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ type storage struct {
}

// NewStorage ...
func NewStorage(conn *sql.DB) reader.Reader {
return generic.NewSqlReader(&storage{conn})
func NewStorage(conn *sql.DB, timeout time.Duration) reader.Reader {
return generic.NewSqlReader(&storage{conn}, timeout)
}

// GetConnection return the connection
Expand Down
2 changes: 1 addition & 1 deletion pkg/reader/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func (m *driver) NewConnection(opts reader.ConnOpts) (reader.Reader, error) {
return nil, err
}

return NewStorage(conn, dump), nil
return NewStorage(conn, dump, opts.Timeout), nil
}

func init() {
Expand Down
14 changes: 7 additions & 7 deletions pkg/reader/postgres/reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package postgres
import (
"database/sql"
"strconv"
"time"

"github.com/hellofresh/klepto/pkg/reader"
"github.com/hellofresh/klepto/pkg/reader/generic"
Expand All @@ -17,13 +18,12 @@ type storage struct {
}

// NewStorage ...
func NewStorage(conn *sql.DB, dumper PgDump) reader.Reader {
return generic.NewSqlReader(
&storage{
PgDump: dumper,
connection: conn,
},
)
func NewStorage(conn *sql.DB, dumper PgDump, timeout time.Duration) reader.Reader {
s := &storage{
PgDump: dumper,
connection: conn,
}
return generic.NewSqlReader(s, timeout)
}

func (s *storage) GetConnection() *sql.DB {
Expand Down

0 comments on commit 2b27fed

Please sign in to comment.