Skip to content

Commit

Permalink
Support custom renewal statements in Postgres (hashicorp#2788)
Browse files Browse the repository at this point in the history
* Support custom renewal statements in Postgres

* Refactored out default/custom renew methods
  • Loading branch information
ConstantineXVI authored and briankassouf committed Jun 1, 2017
1 parent ed9ff08 commit d004ad7
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 11 deletions.
48 changes: 37 additions & 11 deletions plugins/database/postgresql/postgresql.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@ import (
"github.com/lib/pq"
)

const postgreSQLTypeName string = "postgres"
const (
postgreSQLTypeName string = "postgres"
defaultPostgresRenewSQL = `
ALTER ROLE "{{name}}" VALID UNTIL '{{expiration}}';
`
)

// New implements builtinplugins.BuiltinFactory
func New() (interface{}, error) {
Expand Down Expand Up @@ -141,31 +146,52 @@ func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernamePrefix s
}

func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error {
// Grab the lock
p.Lock()
defer p.Unlock()

renewStmts := statements.RenewStatements
if renewStmts == "" {
renewStmts = defaultPostgresRenewSQL
}

db, err := p.getConnection()
if err != nil {
return err
}

expirationStr, err := p.GenerateExpiration(expiration)
tx, err := db.Begin()
if err != nil {
return err
}
defer func() {
tx.Rollback()
}()

query := fmt.Sprintf(
"ALTER ROLE %s VALID UNTIL '%s';",
pq.QuoteIdentifier(username),
expirationStr)

stmt, err := db.Prepare(query)
expirationStr, err := p.GenerateExpiration(expiration)
if err != nil {
return err
}
defer stmt.Close()
if _, err := stmt.Exec(); err != nil {

for _, query := range strutil.ParseArbitraryStringSlice(renewStmts, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{
"name": username,
"expiration": expirationStr,
}))
if err != nil {
return err
}

defer stmt.Close()
if _, err := stmt.Exec(); err != nil {
return err
}
}

if err := tx.Commit(); err != nil {
return err
}

Expand Down
22 changes: 22 additions & 0 deletions plugins/database/postgresql/postgresql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,28 @@ func TestPostgreSQL_RenewUser(t *testing.T) {
if err = testCredsExist(t, connURL, username, password); err != nil {
t.Fatalf("Could not connect with new credentials: %s", err)
}
statements.RenewStatements = defaultPostgresRenewSQL
username, password, err = db.CreateUser(statements, "test", time.Now().Add(2*time.Second))
if err != nil {
t.Fatalf("err: %s", err)
}

if err = testCredsExist(t, connURL, username, password); err != nil {
t.Fatalf("Could not connect with new credentials: %s", err)
}

err = db.RenewUser(statements, username, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}

// Sleep longer than the inital expiration time
time.Sleep(2 * time.Second)

if err = testCredsExist(t, connURL, username, password); err != nil {
t.Fatalf("Could not connect with new credentials: %s", err)
}

}

func TestPostgreSQL_RevokeUser(t *testing.T) {
Expand Down

0 comments on commit d004ad7

Please sign in to comment.