From d004ad75dbe1d11e9884eb2eb4e61090602f876d Mon Sep 17 00:00:00 2001 From: Andrew Paulin Date: Thu, 1 Jun 2017 16:18:16 -0400 Subject: [PATCH] Support custom renewal statements in Postgres (#2788) * Support custom renewal statements in Postgres * Refactored out default/custom renew methods --- plugins/database/postgresql/postgresql.go | 48 ++++++++++++++----- .../database/postgresql/postgresql_test.go | 22 +++++++++ 2 files changed, 59 insertions(+), 11 deletions(-) diff --git a/plugins/database/postgresql/postgresql.go b/plugins/database/postgresql/postgresql.go index d60ef8bbe00c..69bfe3405959 100644 --- a/plugins/database/postgresql/postgresql.go +++ b/plugins/database/postgresql/postgresql.go @@ -16,7 +16,12 @@ import ( "github.com/lib/pq" ) -const postgreSQLTypeName string = "postgres" +const ( + postgreSQLTypeName string = "postgres" + defaultPostgresRenewSQL = ` +ALTER ROLE "{{name}}" VALID UNTIL '{{expiration}}'; +` +) // New implements builtinplugins.BuiltinFactory func New() (interface{}, error) { @@ -141,31 +146,52 @@ func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernamePrefix s } func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { - // Grab the lock p.Lock() defer p.Unlock() + renewStmts := statements.RenewStatements + if renewStmts == "" { + renewStmts = defaultPostgresRenewSQL + } + db, err := p.getConnection() if err != nil { return err } - expirationStr, err := p.GenerateExpiration(expiration) + tx, err := db.Begin() if err != nil { return err } + defer func() { + tx.Rollback() + }() - query := fmt.Sprintf( - "ALTER ROLE %s VALID UNTIL '%s';", - pq.QuoteIdentifier(username), - expirationStr) - - stmt, err := db.Prepare(query) + expirationStr, err := p.GenerateExpiration(expiration) if err != nil { return err } - defer stmt.Close() - if _, err := stmt.Exec(); err != nil { + + for _, query := range strutil.ParseArbitraryStringSlice(renewStmts, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{ + "name": username, + "expiration": expirationStr, + })) + if err != nil { + return err + } + + defer stmt.Close() + if _, err := stmt.Exec(); err != nil { + return err + } + } + + if err := tx.Commit(); err != nil { return err } diff --git a/plugins/database/postgresql/postgresql_test.go b/plugins/database/postgresql/postgresql_test.go index 79391dc56ec0..3fd441bc59bd 100644 --- a/plugins/database/postgresql/postgresql_test.go +++ b/plugins/database/postgresql/postgresql_test.go @@ -170,6 +170,28 @@ func TestPostgreSQL_RenewUser(t *testing.T) { if err = testCredsExist(t, connURL, username, password); err != nil { t.Fatalf("Could not connect with new credentials: %s", err) } + statements.RenewStatements = defaultPostgresRenewSQL + username, password, err = db.CreateUser(statements, "test", time.Now().Add(2*time.Second)) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err = testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } + + err = db.RenewUser(statements, username, time.Now().Add(time.Minute)) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Sleep longer than the inital expiration time + time.Sleep(2 * time.Second) + + if err = testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } + } func TestPostgreSQL_RevokeUser(t *testing.T) {