Skip to content

Commit

Permalink
v1.5.0
Browse files Browse the repository at this point in the history
  • Loading branch information
stfnmllr committed Sep 8, 2023
1 parent d79ee08 commit e29656b
Show file tree
Hide file tree
Showing 8 changed files with 184 additions and 37 deletions.
10 changes: 9 additions & 1 deletion RELEASENOTES.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
Release Notes
=============

## v1.5.0

### New features

- Added support of tenant database connection via tenant database name:
- see new Connector method WithDatabase and
- new DSN parameter DSNDatabaseName

## v1.4.0

### Minor revisions
Expand Down Expand Up @@ -224,7 +232,7 @@ Stored procedures:
- Calling stored procedures with sql.Query methods are no longer supported.
- Please use sql.Exec methods instead and [sql.Rows](https://golang.org/pkg/database/sql/#Rows) for table output parameters.

### New features:
### New features

- Stored procedures executed by sql.Exec with parameters do
- support [named](https://pkg.go.dev/database/sql#Named) parameters and
Expand Down
17 changes: 16 additions & 1 deletion driver/authattrs.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,27 @@ type authAttrs struct {
cbmu sync.RWMutex // prevents refresh callbacks from being called in parallel
}

func isJWTToken(token string) bool { return strings.HasPrefix(token, "ey") }

/*
keep c as the instance name, so that the generated help does have
the same instance variable name when included in connector
*/

func isJWTToken(token string) bool { return strings.HasPrefix(token, "ey") }
func (c *authAttrs) clone() *authAttrs {
c.mu.RLock()
defer c.mu.RUnlock()

return &authAttrs{
_username: c._username,
_password: c._password,
_certKey: c._certKey,
_token: c._token,
_refreshPassword: c._refreshPassword,
_refreshClientCert: c._refreshClientCert,
_refreshToken: c._refreshToken,
}
}

func (c *authAttrs) cookieAuth() *p.AuthHnd {
if !c.hasCookie.Load() { // fastpath without lock
Expand Down
9 changes: 9 additions & 0 deletions driver/connattrs.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ type connAttrs struct {
_cesu8Decoder func() transform.Transformer
_cesu8Encoder func() transform.Transformer
_emptyDateAsNull bool
_databaseName string
_logger *slog.Logger
}

Expand Down Expand Up @@ -120,6 +121,7 @@ func (c *connAttrs) clone() *connAttrs {
_cesu8Decoder: c._cesu8Decoder,
_cesu8Encoder: c._cesu8Encoder,
_emptyDateAsNull: c._emptyDateAsNull,
_databaseName: c._databaseName,
_logger: c._logger,
}
}
Expand Down Expand Up @@ -442,6 +444,13 @@ func (c *connAttrs) SetEmptyDateAsNull(emptyDateAsNull bool) {
c._emptyDateAsNull = emptyDateAsNull
}

// DatabaseName returns the tenant database name of the connector.
func (c *connAttrs) DatabaseName() string {
c.mu.RLock()
defer c.mu.RUnlock()
return c._databaseName
}

// Logger returns the Logger instance of the connector.
func (c *connAttrs) Logger() *slog.Logger {
c.mu.RLock()
Expand Down
84 changes: 64 additions & 20 deletions driver/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"net"
"reflect"
"regexp"
"strconv"
"strings"
"sync/atomic"
"time"
Expand Down Expand Up @@ -205,10 +206,17 @@ func isAuthError(err error) bool {
return hdbErrors.Code() == p.HdbErrAuthenticationFailed
}

func newConn(ctx context.Context, metrics *metrics, connAttrs *connAttrs, authAttrs *authAttrs) (driver.Conn, error) {
func connect(ctx context.Context, metrics *metrics, connAttrs *connAttrs, authAttrs *authAttrs) (driver.Conn, error) {
// if database name fetch tenant database host
if connAttrs._databaseName != "" {
if err := fetchHost(ctx, metrics, connAttrs); err != nil {
return nil, err
}
}

// can we connect via cookie?
if auth := authAttrs.cookieAuth(); auth != nil {
conn, err := initConn(ctx, metrics, connAttrs, auth)
conn, err := newSession(ctx, metrics, connAttrs, auth)
if err == nil {
return conn, nil
}
Expand All @@ -225,7 +233,7 @@ func newConn(ctx context.Context, metrics *metrics, connAttrs *connAttrs, authAt
for {
authHnd := authAttrs.authHnd()

conn, err := initConn(ctx, metrics, connAttrs, authHnd)
conn, err := newSession(ctx, metrics, connAttrs, authHnd)
if err == nil {
if method, ok := authHnd.Selected().(auth.CookieGetter); ok {
authAttrs.setCookie(method.Cookie())
Expand Down Expand Up @@ -287,7 +295,7 @@ func (nvs namedValues) LogValue() slog.Value {
// unique connection number.
var connNo atomic.Uint64

func initConn(ctx context.Context, metrics *metrics, attrs *connAttrs, authHnd *p.AuthHnd) (driver.Conn, error) {
func newConn(ctx context.Context, metrics *metrics, attrs *connAttrs) (*conn, error) {
netConn, err := attrs._dialer.DialContext(ctx, attrs._host, dial.DialerOptions{Timeout: attrs._timeout, TCPKeepAlive: attrs._tcpKeepAlive})
if err != nil {
return nil, err
Expand All @@ -298,48 +306,84 @@ func initConn(ctx context.Context, metrics *metrics, attrs *connAttrs, authHnd *
netConn = tls.Client(netConn, attrs._tlsConfig)
}

no := connNo.Add(1)
logger := attrs._logger.With(slog.Uint64("conn", no))
logger := attrs._logger.With(slog.Uint64("conn", connNo.Add(1)))

dbConn := &dbConn{metrics: metrics, conn: netConn, timeout: attrs._timeout, logger: logger}
// buffer connection
rw := bufio.NewReadWriter(bufio.NewReaderSize(dbConn, attrs._bufferSize), bufio.NewWriterSize(dbConn, attrs._bufferSize))

c := &conn{metrics: metrics, connAttrs: attrs, dbConn: dbConn, sqlTrace: sqlTrace.Load(), logger: logger}

protTrace := protTrace.Load()
c.pw = p.NewWriter(rw.Writer, protTrace, logger, attrs._cesu8Encoder, attrs._sessionVariables) // write upstream

c := &conn{
metrics: metrics,
connAttrs: attrs,
dbConn: dbConn,
sqlTrace: sqlTrace.Load(),
logger: logger,
pw: p.NewWriter(rw.Writer, protTrace, logger, attrs._cesu8Encoder, attrs._sessionVariables), // write upstream
pr: p.NewDBReader(rw.Reader, protTrace, logger, attrs._cesu8Decoder), // read downstream
sessionID: defaultSessionID,
}

if err := c.pw.WriteProlog(ctx); err != nil {
dbConn.close()
return nil, err
}

c.pr = p.NewDBReader(rw.Reader, protTrace, logger, attrs._cesu8Decoder) // read downstream
if err := c.pr.ReadProlog(ctx); err != nil {
dbConn.close()
return nil, err
}

c.sessionID = defaultSessionID
c.metrics.chMsg <- gaugeMsg{idx: gaugeConn, v: 1} // increment open connections.
return c, nil
}

if c.sessionID, c.serverOptions, err = c._authenticate(ctx, authHnd, attrs); err != nil {
func fetchHost(ctx context.Context, metrics *metrics, attrs *connAttrs) error {
c, err := newConn(ctx, metrics, attrs)
if err != nil {
return err
}
defer c.Close()
dbi, err := c._dbConnectInfo(ctx, attrs._databaseName)
if err != nil {
return err
}
if !dbi.IsConnected { // if databaseName == "SYSTEMDB" and isConnected == true host and port are initial
attrs._host = net.JoinHostPort(dbi.Host, strconv.Itoa(dbi.Port))
}
return nil
}

func newSession(ctx context.Context, metrics *metrics, attrs *connAttrs, authHnd *p.AuthHnd) (driver.Conn, error) {
c, err := newConn(ctx, metrics, attrs)
if err != nil {
return nil, err
}
if err := c.initSession(ctx, attrs, authHnd); err != nil {
c.Close()
return nil, err
}
return c, nil
}

func (c *conn) initSession(ctx context.Context, attrs *connAttrs, authHnd *p.AuthHnd) (err error) {
if c.sessionID, c.serverOptions, err = c._authenticate(ctx, authHnd, attrs); err != nil {
return err
}
if c.sessionID <= 0 {
return nil, fmt.Errorf("invalid session id %d", c.sessionID)
return fmt.Errorf("invalid session id %d", c.sessionID)
}

c.hdbVersion = parseVersion(c.versionString())
c.fieldTypeCtx = p.NewFieldTypeCtx(int(c.serverOptions[p.CoDataFormatVersion2].(int32)), attrs._emptyDateAsNull)

if attrs._defaultSchema != "" {
if _, err := c.ExecContext(ctx, strings.Join([]string{setDefaultSchema, Identifier(attrs._defaultSchema).String()}, " "), nil); err != nil {
return nil, err
return err
}
}

c.metrics.chMsg <- gaugeMsg{idx: gaugeConn, v: 1} // increment open connections.

return c, nil
return nil
}

func (c *conn) versionString() (version string) {
Expand Down Expand Up @@ -435,8 +479,8 @@ func (c *conn) PrepareContext(ctx context.Context, query string) (stmt driver.St
// Close implements the driver.Conn interface.
func (c *conn) Close() error {
c.metrics.chMsg <- gaugeMsg{idx: gaugeConn, v: -1} // decrement open connections.
// if isBad do not disconnect
if !c.isBad() {
// do not disconnect if isBad or invalid sessionID
if !c.isBad() && c.sessionID != defaultSessionID {
c._disconnect(context.Background()) // ignore error
}
return c.dbConn.close()
Expand Down
26 changes: 21 additions & 5 deletions driver/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,17 @@ type Connector struct {
*connAttrs
*authAttrs

metrics *metrics

connHook func(driver.Conn) driver.Conn
newConn func(ctx context.Context, connAttrs *connAttrs, authAttrs *authAttrs) (driver.Conn, error)
}

// NewConnector returns a new Connector instance with default values.
func NewConnector() *Connector {
return &Connector{
connAttrs: newConnAttrs(),
authAttrs: &authAttrs{},
newConn: func(ctx context.Context, connAttrs *connAttrs, authAttrs *authAttrs) (driver.Conn, error) {
return newConn(ctx, stdHdbDriver.metrics, connAttrs, authAttrs) // use default stdHdbDriver metrics
},
metrics: stdHdbDriver.metrics, // use default stdHdbDriver metrics
}
}

Expand Down Expand Up @@ -77,6 +76,7 @@ func NewJWTAuthConnector(host, token string) *Connector {
func newDSNConnector(dsn *DSN) (*Connector, error) {
c := NewConnector()
c._host = dsn.host
c._databaseName = dsn.databaseName
c._pingInterval = dsn.pingInterval
c._defaultSchema = dsn.defaultSchema
c.setTimeout(dsn.timeout)
Expand Down Expand Up @@ -104,7 +104,7 @@ func (c *Connector) NativeDriver() Driver { return stdHdbDriver }

// Connect implements the database/sql/driver/Connector interface.
func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) {
conn, err := c.newConn(ctx, c.connAttrs.clone(), c.authAttrs)
conn, err := connect(ctx, c.metrics, c.connAttrs.clone(), c.authAttrs)
if err != nil {
return nil, err
}
Expand All @@ -117,6 +117,22 @@ func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) {
// Driver implements the database/sql/driver/Connector interface.
func (c *Connector) Driver() driver.Driver { return stdHdbDriver }

func (c *Connector) clone() *Connector {
return &Connector{
connAttrs: c.connAttrs.clone(),
authAttrs: c.authAttrs.clone(),
metrics: c.metrics,
connHook: c.connHook,
}
}

// WithDatabase returns a new Connector supporting tenant database connections via database name.
func (c *Connector) WithDatabase(databaseName string) *Connector {
nc := c.clone()
nc._databaseName = databaseName
return nc
}

// SetConnHook sets a function for intercepting connection creation.
// This is for internal use only and might be changed or disabled in future.
func (c *Connector) SetConnHook(fn func(driver.Conn) driver.Conn) { c.connHook = fn }
6 changes: 2 additions & 4 deletions driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
)

// DriverVersion is the version number of the hdb driver.
const DriverVersion = "1.4.7"
const DriverVersion = "1.5.0"

// DriverName is the driver name to use with sql.Open for hdb databases.
const DriverName = "hdb"
Expand Down Expand Up @@ -100,9 +100,7 @@ func OpenDB(c *Connector) *DB {
nc := &Connector{
connAttrs: c.connAttrs,
authAttrs: c.authAttrs,
newConn: func(ctx context.Context, connAttrs *connAttrs, authAttrs *authAttrs) (driver.Conn, error) {
return newConn(ctx, metrics, connAttrs, authAttrs) // use db specific metrics
},
metrics: metrics, // use db specific metrics
}
return &DB{
metrics: metrics,
Expand Down
22 changes: 17 additions & 5 deletions driver/dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

// DSN parameters.
const (
DSNDatabaseName = "databaseName" // Tenant database name.
DSNDefaultSchema = "defaultSchema" // Database default schema.
DSNTimeout = "timeout" // Driver side connection timeout in seconds.
DSNPingInterval = "pingInterval" // Connection ping interval in seconds.
Expand Down Expand Up @@ -42,19 +43,21 @@ A DSN represents a parsed DSN string. A DSN string is an URL string with the fol
and optional query parameters (see DSN query parameters and DSN query default values).
Example:
Examples:
"hdb://myuser:mypassword@localhost:30015?timeout=60"
"hdb://myUser:myPassword@localhost:30015?databaseName=myTenantDatabaseName"
"hdb://myUser:myPassword@localhost:30015?timeout=60"
Examples TLS connection:
"hdb://myuser:mypassword@localhost:39013?TLSRootCAFile=trust.pem"
"hdb://myuser:mypassword@localhost:39013?TLSRootCAFile=trust.pem&TLSServerName=hostname"
"hdb://myuser:mypassword@localhost:39013?TLSInsecureSkipVerify"
"hdb://myUser:myPassword@localhost:39013?TLSRootCAFile=trust.pem"
"hdb://myUser:myPassword@localhost:39013?TLSRootCAFile=trust.pem&TLSServerName=hostname"
"hdb://myUser:myPassword@localhost:39013?TLSInsecureSkipVerify"
*/
type DSN struct {
host string
username, password string
databaseName string
defaultSchema string
timeout time.Duration
pingInterval time.Duration
Expand Down Expand Up @@ -120,6 +123,12 @@ func parseDSN(s string) (*DSN, error) {
default:
return nil, parameterNotSupportedError(k)

case DSNDatabaseName:
if len(v) != 1 {
return nil, invalidNumberOfParametersError(k, len(v), 1)
}
dsn.databaseName = v[0]

case DSNDefaultSchema:
if len(v) != 1 {
return nil, invalidNumberOfParametersError(k, len(v), 1)
Expand Down Expand Up @@ -187,6 +196,9 @@ func parseDSN(s string) (*DSN, error) {
// String reassembles the DSN into a valid DSN string.
func (dsn *DSN) String() string {
values := url.Values{}
if dsn.databaseName != "" {
values.Set(DSNDatabaseName, dsn.databaseName)
}
if dsn.defaultSchema != "" {
values.Set(DSNDefaultSchema, dsn.defaultSchema)
}
Expand Down
Loading

0 comments on commit e29656b

Please sign in to comment.