From 3d77a9a6f4376ca75ffd69861cae8b5fc2732b88 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 19 Dec 2016 10:15:58 -0800 Subject: [PATCH 001/152] Begin work on database refactor --- builtin/logical/database/backend.go | 104 +++ builtin/logical/database/backend_test.go | 620 ++++++++++++++++++ builtin/logical/database/dbs/cassandra.go | 194 ++++++ builtin/logical/database/dbs/db.go | 56 ++ builtin/logical/database/dbs/postgresql.go | 336 ++++++++++ .../database/path_config_connection.go | 188 ++++++ builtin/logical/database/path_config_lease.go | 103 +++ builtin/logical/database/path_role_create.go | 120 ++++ builtin/logical/database/path_roles.go | 161 +++++ builtin/logical/database/secret_creds.go | 147 +++++ cli/commands.go | 2 + 11 files changed, 2031 insertions(+) create mode 100644 builtin/logical/database/backend.go create mode 100644 builtin/logical/database/backend_test.go create mode 100644 builtin/logical/database/dbs/cassandra.go create mode 100644 builtin/logical/database/dbs/db.go create mode 100644 builtin/logical/database/dbs/postgresql.go create mode 100644 builtin/logical/database/path_config_connection.go create mode 100644 builtin/logical/database/path_config_lease.go create mode 100644 builtin/logical/database/path_role_create.go create mode 100644 builtin/logical/database/path_roles.go create mode 100644 builtin/logical/database/secret_creds.go diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go new file mode 100644 index 000000000000..8b7fa36700c5 --- /dev/null +++ b/builtin/logical/database/backend.go @@ -0,0 +1,104 @@ +package database + +import ( + "strings" + "sync" + + log "github.com/mgutz/logxi/v1" + + "github.com/hashicorp/vault/builtin/logical/database/dbs" + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" +) + +func Factory(conf *logical.BackendConfig) (logical.Backend, error) { + return Backend(conf).Setup(conf) +} + +func Backend(conf *logical.BackendConfig) *databaseBackend { + var b databaseBackend + b.Backend = &framework.Backend{ + Help: strings.TrimSpace(backendHelp), + + Paths: []*framework.Path{ + pathConfigConnection(&b), + pathConfigLease(&b), + pathListRoles(&b), + pathRoles(&b), + pathRoleCreate(&b), + }, + + Secrets: []*framework.Secret{ + secretCreds(&b), + }, + + Clean: b.resetAllDBs, + } + + b.logger = conf.Logger + b.connections = make(map[string]dbs.DatabaseType) + return &b +} + +type databaseBackend struct { + connections map[string]dbs.DatabaseType + logger log.Logger + + *framework.Backend + sync.RWMutex +} + +// resetAllDBs closes all connections from all database types +func (b *databaseBackend) resetAllDBs() { + b.logger.Trace("postgres/resetdb: enter") + defer b.logger.Trace("postgres/resetdb: exit") + + b.Lock() + defer b.Unlock() + + for _, db := range b.connections { + db.Close() + } +} + +// Lease returns the lease information +func (b *databaseBackend) Lease(s logical.Storage) (*configLease, error) { + entry, err := s.Get("config/lease") + if err != nil { + return nil, err + } + if entry == nil { + return nil, nil + } + + var result configLease + if err := entry.DecodeJSON(&result); err != nil { + return nil, err + } + + return &result, nil +} + +func (b *databaseBackend) Role(s logical.Storage, n string) (*roleEntry, error) { + entry, err := s.Get("role/" + n) + if err != nil { + return nil, err + } + if entry == nil { + return nil, nil + } + + var result roleEntry + if err := entry.DecodeJSON(&result); err != nil { + return nil, err + } + + return &result, nil +} + +const backendHelp = ` +The PostgreSQL backend dynamically generates database users. + +After mounting this backend, configure it using the endpoints within +the "config/" path. +` diff --git a/builtin/logical/database/backend_test.go b/builtin/logical/database/backend_test.go new file mode 100644 index 000000000000..a203c9b19145 --- /dev/null +++ b/builtin/logical/database/backend_test.go @@ -0,0 +1,620 @@ +package database + +import ( + "database/sql" + "encoding/json" + "fmt" + "log" + "os" + "path" + "reflect" + "sync" + "testing" + "time" + + "github.com/hashicorp/vault/logical" + logicaltest "github.com/hashicorp/vault/logical/testing" + "github.com/lib/pq" + "github.com/mitchellh/mapstructure" + "github.com/ory-am/dockertest" +) + +var ( + testImagePull sync.Once +) + +func prepareTestContainer(t *testing.T, s logical.Storage, b logical.Backend) (cid dockertest.ContainerID, retURL string) { + if os.Getenv("PG_URL") != "" { + return "", os.Getenv("PG_URL") + } + + // Without this the checks for whether the container has started seem to + // never actually pass. There's really no reason to expose the test + // containers, so don't. + dockertest.BindDockerToLocalhost = "yep" + + testImagePull.Do(func() { + dockertest.Pull("postgres") + }) + + cid, connErr := dockertest.ConnectToPostgreSQL(60, 500*time.Millisecond, func(connURL string) bool { + // This will cause a validation to run + resp, err := b.HandleRequest(&logical.Request{ + Storage: s, + Operation: logical.UpdateOperation, + Path: "config/connection", + Data: map[string]interface{}{ + "connection_url": connURL, + }, + }) + if err != nil || (resp != nil && resp.IsError()) { + // It's likely not up and running yet, so return false and try again + return false + } + if resp == nil { + t.Fatal("expected warning") + } + + retURL = connURL + return true + }) + + if connErr != nil { + t.Fatalf("could not connect to database: %v", connErr) + } + + return +} + +func cleanupTestContainer(t *testing.T, cid dockertest.ContainerID) { + err := cid.KillRemove() + if err != nil { + t.Fatal(err) + } +} + +func TestBackend_config_connection(t *testing.T) { + var resp *logical.Response + var err error + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + configData := map[string]interface{}{ + "connection_url": "sample_connection_url", + "value": "", + "max_open_connections": 9, + "max_idle_connections": 7, + "verify_connection": false, + } + + configReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "config/connection", + Storage: config.StorageView, + Data: configData, + } + resp, err = b.HandleRequest(configReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + configReq.Operation = logical.ReadOperation + resp, err = b.HandleRequest(configReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + delete(configData, "verify_connection") + if !reflect.DeepEqual(configData, resp.Data) { + t.Fatalf("bad: expected:%#v\nactual:%#v\n", configData, resp.Data) + } +} + +func TestBackend_basic(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + cid, connURL := prepareTestContainer(t, config.StorageView, b) + if cid != "" { + defer cleanupTestContainer(t, cid) + } + connData := map[string]interface{}{ + "connection_url": connURL, + } + + logicaltest.Test(t, logicaltest.TestCase{ + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepConfig(t, connData, false), + testAccStepCreateRole(t, "web", testRole, false), + testAccStepReadCreds(t, b, config.StorageView, "web", connURL), + }, + }) +} + +func TestBackend_roleCrud(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + cid, connURL := prepareTestContainer(t, config.StorageView, b) + if cid != "" { + defer cleanupTestContainer(t, cid) + } + connData := map[string]interface{}{ + "connection_url": connURL, + } + + logicaltest.Test(t, logicaltest.TestCase{ + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepConfig(t, connData, false), + testAccStepCreateRole(t, "web", testRole, false), + testAccStepReadRole(t, "web", testRole), + testAccStepDeleteRole(t, "web"), + testAccStepReadRole(t, "web", ""), + }, + }) +} + +func TestBackend_BlockStatements(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + cid, connURL := prepareTestContainer(t, config.StorageView, b) + if cid != "" { + defer cleanupTestContainer(t, cid) + } + connData := map[string]interface{}{ + "connection_url": connURL, + } + + jsonBlockStatement, err := json.Marshal(testBlockStatementRoleSlice) + if err != nil { + t.Fatal(err) + } + + logicaltest.Test(t, logicaltest.TestCase{ + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepConfig(t, connData, false), + // This will also validate the query + testAccStepCreateRole(t, "web-block", testBlockStatementRole, true), + testAccStepCreateRole(t, "web-block", string(jsonBlockStatement), false), + }, + }) +} + +func TestBackend_roleReadOnly(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + cid, connURL := prepareTestContainer(t, config.StorageView, b) + if cid != "" { + defer cleanupTestContainer(t, cid) + } + connData := map[string]interface{}{ + "connection_url": connURL, + } + + logicaltest.Test(t, logicaltest.TestCase{ + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepConfig(t, connData, false), + testAccStepCreateRole(t, "web", testRole, false), + testAccStepCreateRole(t, "web-readonly", testReadOnlyRole, false), + testAccStepReadRole(t, "web-readonly", testReadOnlyRole), + testAccStepCreateTable(t, b, config.StorageView, "web", connURL), + testAccStepReadCreds(t, b, config.StorageView, "web-readonly", connURL), + testAccStepDropTable(t, b, config.StorageView, "web", connURL), + testAccStepDeleteRole(t, "web-readonly"), + testAccStepDeleteRole(t, "web"), + testAccStepReadRole(t, "web-readonly", ""), + }, + }) +} + +func TestBackend_roleReadOnly_revocationSQL(t *testing.T) { + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + + cid, connURL := prepareTestContainer(t, config.StorageView, b) + if cid != "" { + defer cleanupTestContainer(t, cid) + } + connData := map[string]interface{}{ + "connection_url": connURL, + } + + logicaltest.Test(t, logicaltest.TestCase{ + Backend: b, + Steps: []logicaltest.TestStep{ + testAccStepConfig(t, connData, false), + testAccStepCreateRoleWithRevocationSQL(t, "web", testRole, defaultRevocationSQL, false), + testAccStepCreateRoleWithRevocationSQL(t, "web-readonly", testReadOnlyRole, defaultRevocationSQL, false), + testAccStepReadRole(t, "web-readonly", testReadOnlyRole), + testAccStepCreateTable(t, b, config.StorageView, "web", connURL), + testAccStepReadCreds(t, b, config.StorageView, "web-readonly", connURL), + testAccStepDropTable(t, b, config.StorageView, "web", connURL), + testAccStepDeleteRole(t, "web-readonly"), + testAccStepDeleteRole(t, "web"), + testAccStepReadRole(t, "web-readonly", ""), + }, + }) +} + +func testAccStepConfig(t *testing.T, d map[string]interface{}, expectError bool) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.UpdateOperation, + Path: "config/connection", + Data: d, + ErrorOk: true, + Check: func(resp *logical.Response) error { + if expectError { + if resp.Data == nil { + return fmt.Errorf("data is nil") + } + var e struct { + Error string `mapstructure:"error"` + } + if err := mapstructure.Decode(resp.Data, &e); err != nil { + return err + } + if len(e.Error) == 0 { + return fmt.Errorf("expected error, but write succeeded.") + } + return nil + } else if resp != nil && resp.IsError() { + return fmt.Errorf("got an error response: %v", resp.Error()) + } + return nil + }, + } +} + +func testAccStepCreateRole(t *testing.T, name string, sql string, expectFail bool) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.UpdateOperation, + Path: path.Join("roles", name), + Data: map[string]interface{}{ + "sql": sql, + }, + ErrorOk: expectFail, + } +} + +func testAccStepCreateRoleWithRevocationSQL(t *testing.T, name, sql, revocationSQL string, expectFail bool) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.UpdateOperation, + Path: path.Join("roles", name), + Data: map[string]interface{}{ + "sql": sql, + "revocation_sql": revocationSQL, + }, + ErrorOk: expectFail, + } +} + +func testAccStepDeleteRole(t *testing.T, name string) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.DeleteOperation, + Path: path.Join("roles", name), + } +} + +func testAccStepReadCreds(t *testing.T, b logical.Backend, s logical.Storage, name string, connURL string) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.ReadOperation, + Path: path.Join("creds", name), + Check: func(resp *logical.Response) error { + var d struct { + Username string `mapstructure:"username"` + Password string `mapstructure:"password"` + } + if err := mapstructure.Decode(resp.Data, &d); err != nil { + return err + } + log.Printf("[TRACE] Generated credentials: %v", d) + conn, err := pq.ParseURL(connURL) + + if err != nil { + t.Fatal(err) + } + + conn += " timezone=utc" + + db, err := sql.Open("postgres", conn) + if err != nil { + t.Fatal(err) + } + + returnedRows := func() int { + stmt, err := db.Prepare("SELECT DISTINCT schemaname FROM pg_tables WHERE has_table_privilege($1, 'information_schema.role_column_grants', 'select');") + if err != nil { + return -1 + } + defer stmt.Close() + + rows, err := stmt.Query(d.Username) + if err != nil { + return -1 + } + defer rows.Close() + + i := 0 + for rows.Next() { + i++ + } + return i + } + + // minNumPermissions is the minimum number of permissions that will always be present. + const minNumPermissions = 2 + + userRows := returnedRows() + if userRows < minNumPermissions { + t.Fatalf("did not get expected number of rows, got %d", userRows) + } + + resp, err = b.HandleRequest(&logical.Request{ + Operation: logical.RevokeOperation, + Storage: s, + Secret: &logical.Secret{ + InternalData: map[string]interface{}{ + "secret_type": "creds", + "username": d.Username, + "role": name, + }, + }, + }) + if err != nil { + return err + } + if resp != nil { + if resp.IsError() { + return fmt.Errorf("Error on resp: %#v", *resp) + } + } + + userRows = returnedRows() + // User shouldn't exist so returnedRows() should encounter an error and exit with -1 + if userRows != -1 { + t.Fatalf("did not get expected number of rows, got %d", userRows) + } + + return nil + }, + } +} + +func testAccStepCreateTable(t *testing.T, b logical.Backend, s logical.Storage, name string, connURL string) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.ReadOperation, + Path: path.Join("creds", name), + Check: func(resp *logical.Response) error { + var d struct { + Username string `mapstructure:"username"` + Password string `mapstructure:"password"` + } + if err := mapstructure.Decode(resp.Data, &d); err != nil { + return err + } + log.Printf("[TRACE] Generated credentials: %v", d) + conn, err := pq.ParseURL(connURL) + + if err != nil { + t.Fatal(err) + } + + conn += " timezone=utc" + + db, err := sql.Open("postgres", conn) + if err != nil { + t.Fatal(err) + } + + _, err = db.Exec("CREATE TABLE test (id SERIAL PRIMARY KEY);") + if err != nil { + t.Fatal(err) + } + + resp, err = b.HandleRequest(&logical.Request{ + Operation: logical.RevokeOperation, + Storage: s, + Secret: &logical.Secret{ + InternalData: map[string]interface{}{ + "secret_type": "creds", + "username": d.Username, + }, + }, + }) + if err != nil { + return err + } + if resp != nil { + if resp.IsError() { + return fmt.Errorf("Error on resp: %#v", *resp) + } + } + + return nil + }, + } +} + +func testAccStepDropTable(t *testing.T, b logical.Backend, s logical.Storage, name string, connURL string) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.ReadOperation, + Path: path.Join("creds", name), + Check: func(resp *logical.Response) error { + var d struct { + Username string `mapstructure:"username"` + Password string `mapstructure:"password"` + } + if err := mapstructure.Decode(resp.Data, &d); err != nil { + return err + } + log.Printf("[TRACE] Generated credentials: %v", d) + conn, err := pq.ParseURL(connURL) + + if err != nil { + t.Fatal(err) + } + + conn += " timezone=utc" + + db, err := sql.Open("postgres", conn) + if err != nil { + t.Fatal(err) + } + + _, err = db.Exec("DROP TABLE test;") + if err != nil { + t.Fatal(err) + } + + resp, err = b.HandleRequest(&logical.Request{ + Operation: logical.RevokeOperation, + Storage: s, + Secret: &logical.Secret{ + InternalData: map[string]interface{}{ + "secret_type": "creds", + "username": d.Username, + }, + }, + }) + if err != nil { + return err + } + if resp != nil { + if resp.IsError() { + return fmt.Errorf("Error on resp: %#v", *resp) + } + } + + return nil + }, + } +} + +func testAccStepReadRole(t *testing.T, name string, sql string) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.ReadOperation, + Path: "roles/" + name, + Check: func(resp *logical.Response) error { + if resp == nil { + if sql == "" { + return nil + } + + return fmt.Errorf("bad: %#v", resp) + } + + var d struct { + SQL string `mapstructure:"sql"` + } + if err := mapstructure.Decode(resp.Data, &d); err != nil { + return err + } + + if d.SQL != sql { + return fmt.Errorf("bad: %#v", resp) + } + + return nil + }, + } +} + +const testRole = ` +CREATE ROLE "{{name}}" WITH + LOGIN + PASSWORD '{{password}}' + VALID UNTIL '{{expiration}}'; +GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}"; +` + +const testReadOnlyRole = ` +CREATE ROLE "{{name}}" WITH + LOGIN + PASSWORD '{{password}}' + VALID UNTIL '{{expiration}}'; +GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}"; +GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}"; +` + +const testBlockStatementRole = ` +DO $$ +BEGIN + IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN + CREATE ROLE "foo-role"; + CREATE SCHEMA IF NOT EXISTS foo AUTHORIZATION "foo-role"; + ALTER ROLE "foo-role" SET search_path = foo; + GRANT TEMPORARY ON DATABASE "postgres" TO "foo-role"; + GRANT ALL PRIVILEGES ON SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA foo TO "foo-role"; + END IF; +END +$$ + +CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}'; +GRANT "foo-role" TO "{{name}}"; +ALTER ROLE "{{name}}" SET search_path = foo; +GRANT CONNECT ON DATABASE "postgres" TO "{{name}}"; +` + +var testBlockStatementRoleSlice = []string{ + ` +DO $$ +BEGIN + IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN + CREATE ROLE "foo-role"; + CREATE SCHEMA IF NOT EXISTS foo AUTHORIZATION "foo-role"; + ALTER ROLE "foo-role" SET search_path = foo; + GRANT TEMPORARY ON DATABASE "postgres" TO "foo-role"; + GRANT ALL PRIVILEGES ON SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA foo TO "foo-role"; + END IF; +END +$$ +`, + `CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}';`, + `GRANT "foo-role" TO "{{name}}";`, + `ALTER ROLE "{{name}}" SET search_path = foo;`, + `GRANT CONNECT ON DATABASE "postgres" TO "{{name}}";`, +} + +const defaultRevocationSQL = ` +REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM {{name}}; +REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM {{name}}; +REVOKE USAGE ON SCHEMA public FROM {{name}}; + +DROP ROLE IF EXISTS {{name}}; +` diff --git a/builtin/logical/database/dbs/cassandra.go b/builtin/logical/database/dbs/cassandra.go new file mode 100644 index 000000000000..8c7a068becfd --- /dev/null +++ b/builtin/logical/database/dbs/cassandra.go @@ -0,0 +1,194 @@ +package dbs + +import ( + "crypto/tls" + "database/sql" + "fmt" + "strings" + "sync" + "time" + + "github.com/gocql/gocql" + "github.com/hashicorp/vault/helper/certutil" + "github.com/hashicorp/vault/helper/tlsutil" +) + +type Cassandra struct { + // Session is goroutine safe, however, since we reinitialize + // it when connection info changes, we want to make sure we + // can close it and use a new connection; hence the lock + session *gocql.Session + config ConnectionConfig + + sync.RWMutex +} + +func (c *Cassandra) Type() string { + return cassandraTypeName +} + +func (c *Cassandra) Connection() (*gocql.Session, error) { + // Grab the write lock + c.Lock() + defer c.Unlock() + + // If we already have a DB, we got it! + if c.session != nil { + return c.session, nil + } + + session, err := createSession(c.config) + if err != nil { + return nil, err + } + + // Store the session in backend for reuse + c.session = session + + return session, nil +} + +func (p *Cassandra) Close() { + // Grab the write lock + p.Lock() + defer p.Unlock() + + if p.session != nil { + p.session.Close() + } + + p.session = nil +} + +func (p *Cassandra) Reset(config ConnectionConfig) (*sql.DB, error) { + // Grab the write lock + p.Lock() + p.config = config + p.Unlock() + + p.Close() + return p.Connection() +} + +func (p *Cassandra) CreateUser(createStmt, username, password, expiration string) error { + // Get the connection + db, err := p.Connection() + if err != nil { + return err + } + + // TODO: This is racey + // Grab a read lock + p.RLock() + defer p.RUnlock() + + return nil +} + +func (p *Cassandra) RenewUser(username, expiration string) error { + db, err := p.Connection() + if err != nil { + return err + } + // TODO: This is Racey + // Grab the read lock + p.RLock() + defer p.RUnlock() + + return nil +} + +func (p *Cassandra) CustomRevokeUser(username, revocationSQL string) error { + db, err := p.Connection() + if err != nil { + return err + } + // TODO: this is Racey + p.RLock() + defer p.RUnlock() + + return nil +} + +func (p *Cassandra) DefaultRevokeUser(username string) error { + // Grab the read lock + p.RLock() + defer p.RUnlock() + + db, err := p.Connection() + + return nil +} + +func createSession(cfg *ConnectionConfig) (*gocql.Session, error) { + clusterConfig := gocql.NewCluster(strings.Split(cfg.Hosts, ",")...) + clusterConfig.Authenticator = gocql.PasswordAuthenticator{ + Username: cfg.Username, + Password: cfg.Password, + } + + clusterConfig.ProtoVersion = cfg.ProtocolVersion + if clusterConfig.ProtoVersion == 0 { + clusterConfig.ProtoVersion = 2 + } + + clusterConfig.Timeout = time.Duration(cfg.ConnectTimeout) * time.Second + + if cfg.TLS { + var tlsConfig *tls.Config + if len(cfg.Certificate) > 0 || len(cfg.IssuingCA) > 0 { + if len(cfg.Certificate) > 0 && len(cfg.PrivateKey) == 0 { + return nil, fmt.Errorf("Found certificate for TLS authentication but no private key") + } + + certBundle := &certutil.CertBundle{} + if len(cfg.Certificate) > 0 { + certBundle.Certificate = cfg.Certificate + certBundle.PrivateKey = cfg.PrivateKey + } + if len(cfg.IssuingCA) > 0 { + certBundle.IssuingCA = cfg.IssuingCA + } + + parsedCertBundle, err := certBundle.ToParsedCertBundle() + if err != nil { + return nil, fmt.Errorf("failed to parse certificate bundle: %s", err) + } + + tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient) + if err != nil || tlsConfig == nil { + return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%v", tlsConfig, err) + } + tlsConfig.InsecureSkipVerify = cfg.InsecureTLS + + if cfg.TLSMinVersion != "" { + var ok bool + tlsConfig.MinVersion, ok = tlsutil.TLSLookup[cfg.TLSMinVersion] + if !ok { + return nil, fmt.Errorf("invalid 'tls_min_version' in config") + } + } else { + // MinVersion was not being set earlier. Reset it to + // zero to gracefully handle upgrades. + tlsConfig.MinVersion = 0 + } + } + + clusterConfig.SslOpts = &gocql.SslOptions{ + Config: *tlsConfig, + } + } + + session, err := clusterConfig.CreateSession() + if err != nil { + return nil, fmt.Errorf("Error creating session: %s", err) + } + + // Verify the info + err = session.Query(`LIST USERS`).Exec() + if err != nil { + return nil, fmt.Errorf("Error validating connection info: %s", err) + } + + return session, nil +} diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go new file mode 100644 index 000000000000..ee7b15b64d40 --- /dev/null +++ b/builtin/logical/database/dbs/db.go @@ -0,0 +1,56 @@ +package dbs + +import ( + "database/sql" + "errors" + "fmt" + "strings" +) + +const ( + postgreSQLTypeName = "postgres" + cassandraTypeName = "cassandra" +) + +var ( + ErrUnsupportedDatabaseType = errors.New("Unsupported database type") +) + +func Factory(conf ConnectionConfig) (DatabaseType, error) { + switch conf.ConnectionType { + case postgreSQLTypeName: + return &PostgreSQL{ + config: conf, + }, nil + } + + return nil, ErrUnsupportedDatabaseType +} + +type DatabaseType interface { + Type() string + Connection() (*sql.DB, error) + Close() + Reset(ConnectionConfig) (*sql.DB, error) + CreateUser(createStmt, username, password, expiration string) error + RenewUser(username, expiration string) error + CustomRevokeUser(username, revocationSQL string) error + DefaultRevokeUser(username string) error +} + +type ConnectionConfig struct { + ConnectionType string `json:"type" structs:"type" mapstructure:"type"` + ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"` + ConnectionDetails map[string]string `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"` + MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"` + MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"` +} + +// Query templates a query for us. +func queryHelper(tpl string, data map[string]string) string { + for k, v := range data { + tpl = strings.Replace(tpl, fmt.Sprintf("{{%s}}", k), v, -1) + } + + return tpl +} diff --git a/builtin/logical/database/dbs/postgresql.go b/builtin/logical/database/dbs/postgresql.go new file mode 100644 index 000000000000..ea7d08f8ac78 --- /dev/null +++ b/builtin/logical/database/dbs/postgresql.go @@ -0,0 +1,336 @@ +package dbs + +import ( + "database/sql" + "fmt" + "strings" + "sync" + + "github.com/hashicorp/vault/helper/strutil" + "github.com/lib/pq" +) + +type PostgreSQL struct { + db *sql.DB + config ConnectionConfig + + sync.RWMutex +} + +func (p *PostgreSQL) Type() string { + return postgreSQLTypeName +} + +func (p *PostgreSQL) Connection() (*sql.DB, error) { + // Grab the write lock + p.Lock() + defer p.Unlock() + + // If we already have a DB, we got it! + if p.db != nil { + if err := p.db.Ping(); err == nil { + return p.db, nil + } + // If the ping was unsuccessful, close it and ignore errors as we'll be + // reestablishing anyways + p.db.Close() + } + + // Otherwise, attempt to make connection + conn := p.config.ConnectionURL + + // Ensure timezone is set to UTC for all the conenctions + if strings.HasPrefix(conn, "postgres://") || strings.HasPrefix(conn, "postgresql://") { + if strings.Contains(conn, "?") { + conn += "&timezone=utc" + } else { + conn += "?timezone=utc" + } + } else { + conn += " timezone=utc" + } + + var err error + p.db, err = sql.Open("postgres", conn) + if err != nil { + return nil, err + } + + // Set some connection pool settings. We don't need much of this, + // since the request rate shouldn't be high. + p.db.SetMaxOpenConns(p.config.MaxOpenConnections) + p.db.SetMaxIdleConns(p.config.MaxIdleConnections) + + return p.db, nil +} + +func (p *PostgreSQL) Close() { + // Grab the write lock + p.Lock() + defer p.Unlock() + + if p.db != nil { + p.db.Close() + } + + p.db = nil +} + +func (p *PostgreSQL) Reset(config ConnectionConfig) (*sql.DB, error) { + // Grab the write lock + p.Lock() + p.config = config + p.Unlock() + + p.Close() + return p.Connection() +} + +func (p *PostgreSQL) CreateUser(createStmt, username, password, expiration string) error { + // Get the connection + db, err := p.Connection() + if err != nil { + return err + } + + // TODO: This is racey + // Grab a read lock + p.RLock() + defer p.RUnlock() + + // Start a transaction + // b.logger.Trace("postgres/pathRoleCreateRead: starting transaction") + tx, err := db.Begin() + if err != nil { + return err + } + defer func() { + // b.logger.Trace("postgres/pathRoleCreateRead: rolling back transaction") + tx.Rollback() + }() + // Return the secret + + // Execute each query + for _, query := range strutil.ParseArbitraryStringSlice(createStmt, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + // b.logger.Trace("postgres/pathRoleCreateRead: preparing statement") + stmt, err := tx.Prepare(queryHelper(query, map[string]string{ + "name": username, + "password": password, + "expiration": expiration, + })) + if err != nil { + return err + } + defer stmt.Close() + // b.logger.Trace("postgres/pathRoleCreateRead: executing statement") + if _, err := stmt.Exec(); err != nil { + return err + } + } + + // Commit the transaction + + // b.logger.Trace("postgres/pathRoleCreateRead: committing transaction") + if err := tx.Commit(); err != nil { + return err + } + + return nil +} + +func (p *PostgreSQL) RenewUser(username, expiration string) error { + db, err := p.Connection() + if err != nil { + return err + } + // TODO: This is Racey + // Grab the read lock + p.RLock() + defer p.RUnlock() + + query := fmt.Sprintf( + "ALTER ROLE %s VALID UNTIL '%s';", + pq.QuoteIdentifier(username), + expiration) + + stmt, err := db.Prepare(query) + if err != nil { + return err + } + defer stmt.Close() + if _, err := stmt.Exec(); err != nil { + return err + } + + return nil +} + +func (p *PostgreSQL) CustomRevokeUser(username, revocationSQL string) error { + db, err := p.Connection() + if err != nil { + return err + } + // TODO: this is Racey + p.RLock() + defer p.RUnlock() + + tx, err := db.Begin() + if err != nil { + return err + } + defer func() { + tx.Rollback() + }() + + for _, query := range strutil.ParseArbitraryStringSlice(revocationSQL, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + stmt, err := tx.Prepare(queryHelper(query, map[string]string{ + "name": username, + })) + if err != nil { + return err + } + defer stmt.Close() + + if _, err := stmt.Exec(); err != nil { + return err + } + } + + if err := tx.Commit(); err != nil { + return err + } + + return nil +} + +func (p *PostgreSQL) DefaultRevokeUser(username string) error { + // Grab the read lock + p.RLock() + defer p.RUnlock() + + db, err := p.Connection() + if err != nil { + return err + } + + // Check if the role exists + var exists bool + err = db.QueryRow("SELECT exists (SELECT rolname FROM pg_roles WHERE rolname=$1);", username).Scan(&exists) + if err != nil && err != sql.ErrNoRows { + return err + } + + if exists == false { + return nil + } + + // Query for permissions; we need to revoke permissions before we can drop + // the role + // This isn't done in a transaction because even if we fail along the way, + // we want to remove as much access as possible + stmt, err := db.Prepare("SELECT DISTINCT table_schema FROM information_schema.role_column_grants WHERE grantee=$1;") + if err != nil { + return err + } + defer stmt.Close() + + rows, err := stmt.Query(username) + if err != nil { + return err + } + defer rows.Close() + + const initialNumRevocations = 16 + revocationStmts := make([]string, 0, initialNumRevocations) + for rows.Next() { + var schema string + err = rows.Scan(&schema) + if err != nil { + // keep going; remove as many permissions as possible right now + continue + } + revocationStmts = append(revocationStmts, fmt.Sprintf( + `REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA %s FROM %s;`, + pq.QuoteIdentifier(schema), + pq.QuoteIdentifier(username))) + + revocationStmts = append(revocationStmts, fmt.Sprintf( + `REVOKE USAGE ON SCHEMA %s FROM %s;`, + pq.QuoteIdentifier(schema), + pq.QuoteIdentifier(username))) + } + + // for good measure, revoke all privileges and usage on schema public + revocationStmts = append(revocationStmts, fmt.Sprintf( + `REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM %s;`, + pq.QuoteIdentifier(username))) + + revocationStmts = append(revocationStmts, fmt.Sprintf( + "REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM %s;", + pq.QuoteIdentifier(username))) + + revocationStmts = append(revocationStmts, fmt.Sprintf( + "REVOKE USAGE ON SCHEMA public FROM %s;", + pq.QuoteIdentifier(username))) + + // get the current database name so we can issue a REVOKE CONNECT for + // this username + var dbname sql.NullString + if err := db.QueryRow("SELECT current_database();").Scan(&dbname); err != nil { + return err + } + + if dbname.Valid { + revocationStmts = append(revocationStmts, fmt.Sprintf( + `REVOKE CONNECT ON DATABASE %s FROM %s;`, + pq.QuoteIdentifier(dbname.String), + pq.QuoteIdentifier(username))) + } + + // again, here, we do not stop on error, as we want to remove as + // many permissions as possible right now + var lastStmtError error + for _, query := range revocationStmts { + stmt, err := db.Prepare(query) + if err != nil { + lastStmtError = err + continue + } + defer stmt.Close() + _, err = stmt.Exec() + if err != nil { + lastStmtError = err + } + } + + // can't drop if not all privileges are revoked + if rows.Err() != nil { + return fmt.Errorf("could not generate revocation statements for all rows: %s", rows.Err()) + } + if lastStmtError != nil { + return fmt.Errorf("could not perform all revocation statements: %s", lastStmtError) + } + + // Drop this user + stmt, err = db.Prepare(fmt.Sprintf( + `DROP ROLE IF EXISTS %s;`, pq.QuoteIdentifier(username))) + if err != nil { + return err + } + defer stmt.Close() + if _, err := stmt.Exec(); err != nil { + return err + } + + return nil +} diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go new file mode 100644 index 000000000000..be017ea35c02 --- /dev/null +++ b/builtin/logical/database/path_config_connection.go @@ -0,0 +1,188 @@ +package database + +import ( + "fmt" + + "github.com/fatih/structs" + "github.com/hashicorp/vault/builtin/logical/database/dbs" + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" + _ "github.com/lib/pq" +) + +func pathConfigConnection(b *databaseBackend) *framework.Path { + return &framework.Path{ + Pattern: fmt.Sprintf("dbs/%s", framework.GenericNameRegex("name")), + Fields: map[string]*framework.FieldSchema{ + "name": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "Name of this DB type", + }, + + "connection_type": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "DB type (e.g. postgres)", + }, + + "connection_url": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "DB connection string", + }, + + "connection_details": &framework.FieldSchema{ + Type: framework.TypeMap, + Description: "Connection details for specified connection type.", + }, + + "verify_connection": &framework.FieldSchema{ + Type: framework.TypeBool, + Default: true, + Description: `If set, connection_url is verified by actually connecting to the database`, + }, + + "max_open_connections": &framework.FieldSchema{ + Type: framework.TypeInt, + Description: `Maximum number of open connections to the database; +a zero uses the default value of two and a +negative value means unlimited`, + }, + + "max_idle_connections": &framework.FieldSchema{ + Type: framework.TypeInt, + Description: `Maximum number of idle connections to the database; +a zero uses the value of max_open_connections +and a negative value disables idle connections. +If larger than max_open_connections it will be +reduced to the same size.`, + }, + }, + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.UpdateOperation: b.pathConnectionWrite, + logical.ReadOperation: b.pathConnectionRead, + }, + + HelpSynopsis: pathConfigConnectionHelpSyn, + HelpDescription: pathConfigConnectionHelpDesc, + } +} + +// pathConnectionRead reads out the connection configuration +func (b *databaseBackend) pathConnectionRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + name := data.Get("name").(string) + + entry, err := req.Storage.Get(fmt.Sprintf("dbs/%s", name)) + if err != nil { + return nil, fmt.Errorf("failed to read connection configuration") + } + if entry == nil { + return nil, nil + } + + var config dbs.ConnectionConfig + if err := entry.DecodeJSON(&config); err != nil { + return nil, err + } + return &logical.Response{ + Data: structs.New(config).Map(), + }, nil +} + +func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + connURL := data.Get("connection_url").(string) + connType := data.Get("connection_type").(string) + + maxOpenConns := data.Get("max_open_connections").(int) + if maxOpenConns == 0 { + maxOpenConns = 2 + } + + maxIdleConns := data.Get("max_idle_connections").(int) + if maxIdleConns == 0 { + maxIdleConns = maxOpenConns + } + if maxIdleConns > maxOpenConns { + maxIdleConns = maxOpenConns + } + + config := dbs.ConnectionConfig{ + ConnectionType: connType, + ConnectionURL: connURL, + MaxOpenConnections: maxOpenConns, + MaxIdleConnections: maxIdleConns, + } + + name := data.Get("name").(string) + + // Grab the mutex lock + b.Lock() + defer b.Unlock() + + var err error + var db dbs.DatabaseType + if _, ok := b.connections[name]; ok { + + // Don't allow the connection type to change + if b.connections[name].Type() != connType { + return logical.ErrorResponse("can not change type of existing connection"), nil + } + + db = b.connections[name] + } else { + db, err = dbs.Factory(config) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil + } + } + + /* + // Don't check the connection_url if verification is disabled + verifyConnection := data.Get("verify_connection").(bool) + if verifyConnection { + // Verify the string + db, err := sql.Open("postgres", connURL) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf( + "Error validating connection info: %s", err)), nil + } + defer db.Close() + if err := db.Ping(); err != nil { + return logical.ErrorResponse(fmt.Sprintf( + "Error validating connection info: %s", err)), nil + } + } + */ + + // Store it + entry, err := logical.StorageEntryJSON(fmt.Sprintf("dbs/%s", name), config) + if err != nil { + return nil, err + } + if err := req.Storage.Put(entry); err != nil { + return nil, err + } + + // Reset the DB connection + db.Reset(config) + b.connections[name] = db + + resp := &logical.Response{} + resp.AddWarning("Read access to this endpoint should be controlled via ACLs as it will return the connection string or URL as it is, including passwords, if any.") + + return resp, nil +} + +const pathConfigConnectionHelpSyn = ` +Configure the connection string to talk to PostgreSQL. +` + +const pathConfigConnectionHelpDesc = ` +This path configures the connection string used to connect to PostgreSQL. +The value of the string can be a URL, or a PG style string in the +format of "user=foo host=bar" etc. + +The URL looks like: +"postgresql://user:pass@host:port/dbname" + +When configuring the connection string, the backend will verify its validity. +` diff --git a/builtin/logical/database/path_config_lease.go b/builtin/logical/database/path_config_lease.go new file mode 100644 index 000000000000..5cc40a056e9d --- /dev/null +++ b/builtin/logical/database/path_config_lease.go @@ -0,0 +1,103 @@ +package database + +import ( + "fmt" + "time" + + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" +) + +func pathConfigLease(b *databaseBackend) *framework.Path { + return &framework.Path{ + Pattern: "config/lease", + Fields: map[string]*framework.FieldSchema{ + "lease": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "Default lease for roles.", + }, + + "lease_max": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "Maximum time a credential is valid for.", + }, + }, + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.ReadOperation: b.pathLeaseRead, + logical.UpdateOperation: b.pathLeaseWrite, + }, + + HelpSynopsis: pathConfigLeaseHelpSyn, + HelpDescription: pathConfigLeaseHelpDesc, + } +} + +func (b *databaseBackend) pathLeaseWrite( + req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + leaseRaw := d.Get("lease").(string) + leaseMaxRaw := d.Get("lease_max").(string) + + lease, err := time.ParseDuration(leaseRaw) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf( + "Invalid lease: %s", err)), nil + } + leaseMax, err := time.ParseDuration(leaseMaxRaw) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf( + "Invalid lease: %s", err)), nil + } + + // Store it + entry, err := logical.StorageEntryJSON("config/lease", &configLease{ + Lease: lease, + LeaseMax: leaseMax, + }) + if err != nil { + return nil, err + } + if err := req.Storage.Put(entry); err != nil { + return nil, err + } + + return nil, nil +} + +func (b *databaseBackend) pathLeaseRead( + req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + lease, err := b.Lease(req.Storage) + + if err != nil { + return nil, err + } + if lease == nil { + return nil, nil + } + + return &logical.Response{ + Data: map[string]interface{}{ + "lease": lease.Lease.String(), + "lease_max": lease.LeaseMax.String(), + }, + }, nil +} + +type configLease struct { + Lease time.Duration + LeaseMax time.Duration +} + +const pathConfigLeaseHelpSyn = ` +Configure the default lease information for generated credentials. +` + +const pathConfigLeaseHelpDesc = ` +This configures the default lease information used for credentials +generated by this backend. The lease specifies the duration that a +credential will be valid for, as well as the maximum session for +a set of credentials. + +The format for the lease is "1h" or integer and then unit. The longest +unit is hour. +` diff --git a/builtin/logical/database/path_role_create.go b/builtin/logical/database/path_role_create.go new file mode 100644 index 000000000000..2a2386d01213 --- /dev/null +++ b/builtin/logical/database/path_role_create.go @@ -0,0 +1,120 @@ +package database + +import ( + "fmt" + "time" + + "github.com/hashicorp/go-uuid" + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" + _ "github.com/lib/pq" +) + +func pathRoleCreate(b *databaseBackend) *framework.Path { + return &framework.Path{ + Pattern: "creds/" + framework.GenericNameRegex("name"), + Fields: map[string]*framework.FieldSchema{ + "name": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "Name of the role.", + }, + }, + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.ReadOperation: b.pathRoleCreateRead, + }, + + HelpSynopsis: pathRoleCreateReadHelpSyn, + HelpDescription: pathRoleCreateReadHelpDesc, + } +} + +func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + b.logger.Trace("postgres/pathRoleCreateRead: enter") + defer b.logger.Trace("postgres/pathRoleCreateRead: exit") + + name := data.Get("name").(string) + + // Get the role + b.logger.Trace("postgres/pathRoleCreateRead: getting role") + role, err := b.Role(req.Storage, name) + if err != nil { + return nil, err + } + if role == nil { + return logical.ErrorResponse(fmt.Sprintf("unknown role: %s", name)), nil + } + + // Determine if we have a lease + b.logger.Trace("postgres/pathRoleCreateRead: getting lease") + lease, err := b.Lease(req.Storage) + if err != nil { + return nil, err + } + // Unlike some other backends we need a lease here (can't leave as 0 and + // let core fill it in) because Postgres also expires users as a safety + // measure, so cannot be zero + if lease == nil { + lease = &configLease{ + Lease: b.System().DefaultLeaseTTL(), + } + } + + // Generate the username, password and expiration. PG limits user to 63 characters + displayName := req.DisplayName + if len(displayName) > 26 { + displayName = displayName[:26] + } + userUUID, err := uuid.GenerateUUID() + if err != nil { + return nil, err + } + username := fmt.Sprintf("%s-%s", displayName, userUUID) + if len(username) > 63 { + username = username[:63] + } + password, err := uuid.GenerateUUID() + if err != nil { + return nil, err + } + expiration := time.Now(). + Add(lease.Lease). + Format("2006-01-02 15:04:05-0700") + + // Get our handle + b.logger.Trace("postgres/pathRoleCreateRead: getting database handle") + + b.RLock() + defer b.RUnlock() + db, ok := b.connections[role.DBName] + if !ok { + // TODO: return a resp error instead? + return nil, fmt.Errorf("Cound not find DB with name: %s", role.DBName) + } + + err = db.CreateUser(role.CreationStatement, username, password, expiration) + if err != nil { + return nil, err + } + + b.logger.Trace("postgres/pathRoleCreateRead: generating secret") + resp := b.Secret(SecretCredsType).Response(map[string]interface{}{ + "username": username, + "password": password, + }, map[string]interface{}{ + "username": username, + "role": name, + }) + resp.Secret.TTL = lease.Lease + return resp, nil +} + +const pathRoleCreateReadHelpSyn = ` +Request database credentials for a certain role. +` + +const pathRoleCreateReadHelpDesc = ` +This path reads database credentials for a certain role. The +database credentials will be generated on demand and will be automatically +revoked when the lease is up. +` diff --git a/builtin/logical/database/path_roles.go b/builtin/logical/database/path_roles.go new file mode 100644 index 000000000000..e06518b289a4 --- /dev/null +++ b/builtin/logical/database/path_roles.go @@ -0,0 +1,161 @@ +package database + +import ( + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" +) + +func pathListRoles(b *databaseBackend) *framework.Path { + return &framework.Path{ + Pattern: "roles/?$", + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.ListOperation: b.pathRoleList, + }, + + HelpSynopsis: pathRoleHelpSyn, + HelpDescription: pathRoleHelpDesc, + } +} + +func pathRoles(b *databaseBackend) *framework.Path { + return &framework.Path{ + Pattern: "roles/" + framework.GenericNameRegex("name"), + Fields: map[string]*framework.FieldSchema{ + "name": { + Type: framework.TypeString, + Description: "Name of the role.", + }, + + "db_name": { + Type: framework.TypeString, + Description: "Name of the database this role acts on.", + }, + + "creation_statement": { + Type: framework.TypeString, + Description: "SQL string to create a user. See help for more info.", + }, + + "revocation_statement": { + Type: framework.TypeString, + Description: `SQL statements to be executed to revoke a user. Must be a semicolon-separated + string, a base64-encoded semicolon-separated string, a serialized JSON string + array, or a base64-encoded serialized JSON string array. The '{{name}}' value + will be substituted.`, + }, + }, + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.ReadOperation: b.pathRoleRead, + logical.UpdateOperation: b.pathRoleCreate, + logical.DeleteOperation: b.pathRoleDelete, + }, + + HelpSynopsis: pathRoleHelpSyn, + HelpDescription: pathRoleHelpDesc, + } +} + +func (b *databaseBackend) pathRoleDelete(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + err := req.Storage.Delete("role/" + data.Get("name").(string)) + if err != nil { + return nil, err + } + + return nil, nil +} + +func (b *databaseBackend) pathRoleRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + role, err := b.Role(req.Storage, data.Get("name").(string)) + if err != nil { + return nil, err + } + if role == nil { + return nil, nil + } + + return &logical.Response{ + Data: map[string]interface{}{ + "creation_statment": role.CreationStatement, + "revocation_statement": role.RevocationStatement, + }, + }, nil +} + +func (b *databaseBackend) pathRoleList(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + entries, err := req.Storage.List("role/") + if err != nil { + return nil, err + } + + return logical.ListResponse(entries), nil +} + +func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + name := data.Get("name").(string) + dbName := data.Get("db_name").(string) + creationStmt := data.Get("creation_statement").(string) + revocationStmt := data.Get("revocation_statement").(string) + + // TODO: Think about preparing the statments to test. + + // Store it + entry, err := logical.StorageEntryJSON("role/"+name, &roleEntry{ + DBName: dbName, + CreationStatement: creationStmt, + RevocationStatement: revocationStmt, + }) + if err != nil { + return nil, err + } + if err := req.Storage.Put(entry); err != nil { + return nil, err + } + + return nil, nil +} + +type roleEntry struct { + DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"` + CreationStatement string `json:"creation_statment" mapstructure:"creation_statement" structs:"creation_statment"` + RevocationStatement string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"` +} + +const pathRoleHelpSyn = ` +Manage the roles that can be created with this backend. +` + +const pathRoleHelpDesc = ` +This path lets you manage the roles that can be created with this backend. + +The "sql" parameter customizes the SQL string used to create the role. +This can be a sequence of SQL queries. Some substitution will be done to the +SQL string for certain keys. The names of the variables must be surrounded +by "{{" and "}}" to be replaced. + + * "name" - The random username generated for the DB user. + + * "password" - The random password generated for the DB user. + + * "expiration" - The timestamp when this user will expire. + +Example of a decent SQL query to use: + + CREATE ROLE "{{name}}" WITH + LOGIN + PASSWORD '{{password}}' + VALID UNTIL '{{expiration}}'; + GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}"; + +Note the above user would be able to access everything in schema public. +For more complex GRANT clauses, see the PostgreSQL manual. + +The "revocation_sql" parameter customizes the SQL string used to revoke a user. +Example of a decent revocation SQL query to use: + + REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM {{name}}; + REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM {{name}}; + REVOKE USAGE ON SCHEMA public FROM {{name}}; + DROP ROLE IF EXISTS {{name}}; +` diff --git a/builtin/logical/database/secret_creds.go b/builtin/logical/database/secret_creds.go new file mode 100644 index 000000000000..30c4a6430f22 --- /dev/null +++ b/builtin/logical/database/secret_creds.go @@ -0,0 +1,147 @@ +package database + +import ( + "errors" + "fmt" + + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" +) + +const SecretCredsType = "creds" + +func secretCreds(b *databaseBackend) *framework.Secret { + return &framework.Secret{ + Type: SecretCredsType, + Fields: map[string]*framework.FieldSchema{ + "username": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "Username", + }, + + "password": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "Password", + }, + }, + + Renew: b.secretCredsRenew, + Revoke: b.secretCredsRevoke, + } +} + +func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + dbName := d.Get("name").(string) + + // Get the username from the internal data + usernameRaw, ok := req.Secret.InternalData["username"] + if !ok { + return nil, fmt.Errorf("secret is missing username internal data") + } + username, ok := usernameRaw.(string) + + // Get our connection + db, ok := b.connections[dbName] + if !ok { + return nil, errors.New(fmt.Sprintf("Could not find connection with name %s", dbName)) + } + + // Get the lease information + lease, err := b.Lease(req.Storage) + if err != nil { + return nil, err + } + if lease == nil { + lease = &configLease{} + } + + f := framework.LeaseExtend(lease.Lease, lease.LeaseMax, b.System()) + resp, err := f(req, d) + if err != nil { + return nil, err + } + + // Make sure we increase the VALID UNTIL endpoint for this user. + if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() { + expiration := expireTime.Format("2006-01-02 15:04:05-0700") + + err := db.RenewUser(username, expiration) + if err != nil { + return nil, err + } + } + + return resp, nil +} + +func (b *databaseBackend) secretCredsRevoke(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + // Get the username from the internal data + usernameRaw, ok := req.Secret.InternalData["username"] + if !ok { + return nil, fmt.Errorf("secret is missing username internal data") + } + username, ok := usernameRaw.(string) + + var revocationSQL string + var resp *logical.Response + + roleNameRaw, ok := req.Secret.InternalData["role"] + if !ok { + return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"]) + } + + role, err := b.Role(req.Storage, roleNameRaw.(string)) + if err != nil { + return nil, err + } + if role == nil { + return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"]) + } + + /* TODO: think about how to handle this case. + if !ok { + role, err := b.Role(req.Storage, roleNameRaw.(string)) + if err != nil { + return nil, err + } + if role == nil { + if resp == nil { + resp = &logical.Response{} + } + resp.AddWarning(fmt.Sprintf("Role %q cannot be found. Using default revocation SQL.", roleNameRaw.(string))) + } else { + revocationSQL = role.RevocationStatement + } + }*/ + + // Grab the read lock + b.RLock() + defer b.RUnlock() + + // Get our connection + db, ok := b.connections[role.DBName] + if !ok { + return nil, fmt.Errorf("Could not find database with name: %s", role.DBName) + } + + // TODO: Maybe move this down into db package? + switch revocationSQL { + + // This is the default revocation logic. If revocation SQL is provided it + // is simply executed as-is. + case "": + err := db.DefaultRevokeUser(username) + if err != nil { + return nil, err + } + + // We have revocation SQL, execute directly, within a transaction + default: + err := db.CustomRevokeUser(username, revocationSQL) + if err != nil { + return nil, err + } + } + + return resp, nil +} diff --git a/cli/commands.go b/cli/commands.go index 190111177953..13f7c8b25aad 100644 --- a/cli/commands.go +++ b/cli/commands.go @@ -21,6 +21,7 @@ import ( "github.com/hashicorp/vault/builtin/logical/aws" "github.com/hashicorp/vault/builtin/logical/cassandra" "github.com/hashicorp/vault/builtin/logical/consul" + "github.com/hashicorp/vault/builtin/logical/database" "github.com/hashicorp/vault/builtin/logical/mongodb" "github.com/hashicorp/vault/builtin/logical/mssql" "github.com/hashicorp/vault/builtin/logical/mysql" @@ -91,6 +92,7 @@ func Commands(metaPtr *meta.Meta) map[string]cli.CommandFactory { "mysql": mysql.Factory, "ssh": ssh.Factory, "rabbitmq": rabbitmq.Factory, + "database": database.Factory, }, ShutdownCh: command.MakeShutdownCh(), SighupCh: command.MakeSighupCh(), From ad17d113c7babe60effb8b5cb0992e905e8c708d Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 20 Dec 2016 11:46:20 -0800 Subject: [PATCH 002/152] More work on refactor and cassandra database --- builtin/logical/database/backend.go | 19 -- builtin/logical/database/dbs/cassandra.go | 194 ++++--------- .../database/dbs/connectionproducer.go | 254 ++++++++++++++++++ .../database/dbs/credentialsproducer.go | 79 ++++++ builtin/logical/database/dbs/db.go | 67 +++-- builtin/logical/database/dbs/postgresql.go | 102 ++----- .../database/path_config_connection.go | 10 +- builtin/logical/database/path_config_lease.go | 103 ------- builtin/logical/database/path_role_create.go | 52 +--- builtin/logical/database/path_roles.go | 50 +++- builtin/logical/database/secret_creds.go | 47 ++-- 11 files changed, 548 insertions(+), 429 deletions(-) create mode 100644 builtin/logical/database/dbs/connectionproducer.go create mode 100644 builtin/logical/database/dbs/credentialsproducer.go delete mode 100644 builtin/logical/database/path_config_lease.go diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index 8b7fa36700c5..3d757df1dcc7 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -22,7 +22,6 @@ func Backend(conf *logical.BackendConfig) *databaseBackend { Paths: []*framework.Path{ pathConfigConnection(&b), - pathConfigLease(&b), pathListRoles(&b), pathRoles(&b), pathRoleCreate(&b), @@ -61,24 +60,6 @@ func (b *databaseBackend) resetAllDBs() { } } -// Lease returns the lease information -func (b *databaseBackend) Lease(s logical.Storage) (*configLease, error) { - entry, err := s.Get("config/lease") - if err != nil { - return nil, err - } - if entry == nil { - return nil, nil - } - - var result configLease - if err := entry.DecodeJSON(&result); err != nil { - return nil, err - } - - return &result, nil -} - func (b *databaseBackend) Role(s logical.Storage, n string) (*roleEntry, error) { entry, err := s.Get("role/" + n) if err != nil { diff --git a/builtin/logical/database/dbs/cassandra.go b/builtin/logical/database/dbs/cassandra.go index 8c7a068becfd..a8889032f93f 100644 --- a/builtin/logical/database/dbs/cassandra.go +++ b/builtin/logical/database/dbs/cassandra.go @@ -1,25 +1,20 @@ package dbs import ( - "crypto/tls" - "database/sql" "fmt" "strings" "sync" - "time" "github.com/gocql/gocql" - "github.com/hashicorp/vault/helper/certutil" - "github.com/hashicorp/vault/helper/tlsutil" + "github.com/hashicorp/vault/helper/strutil" ) type Cassandra struct { // Session is goroutine safe, however, since we reinitialize // it when connection info changes, we want to make sure we // can close it and use a new connection; hence the lock - session *gocql.Session - config ConnectionConfig - + ConnectionProducer + CredentialsProducer sync.RWMutex } @@ -27,168 +22,85 @@ func (c *Cassandra) Type() string { return cassandraTypeName } -func (c *Cassandra) Connection() (*gocql.Session, error) { - // Grab the write lock - c.Lock() - defer c.Unlock() - - // If we already have a DB, we got it! - if c.session != nil { - return c.session, nil - } - - session, err := createSession(c.config) +func (c *Cassandra) getConnection() (*gocql.Session, error) { + session, err := c.Connection() if err != nil { return nil, err } - // Store the session in backend for reuse - c.session = session - - return session, nil -} - -func (p *Cassandra) Close() { - // Grab the write lock - p.Lock() - defer p.Unlock() - - if p.session != nil { - p.session.Close() - } - - p.session = nil + return session.(*gocql.Session), nil } -func (p *Cassandra) Reset(config ConnectionConfig) (*sql.DB, error) { - // Grab the write lock - p.Lock() - p.config = config - p.Unlock() - - p.Close() - return p.Connection() -} - -func (p *Cassandra) CreateUser(createStmt, username, password, expiration string) error { +func (c *Cassandra) CreateUser(createStmt, rollbackStmt, username, password, expiration string) error { // Get the connection - db, err := p.Connection() + session, err := c.getConnection() if err != nil { return err } // TODO: This is racey // Grab a read lock - p.RLock() - defer p.RUnlock() + c.RLock() + defer c.RUnlock() + + // Set consistency + /* if .Consistency != "" { + consistencyValue, err := gocql.ParseConsistencyWrapper(role.Consistency) + if err != nil { + return err + } - return nil -} + session.SetConsistency(consistencyValue) + }*/ -func (p *Cassandra) RenewUser(username, expiration string) error { - db, err := p.Connection() - if err != nil { - return err - } - // TODO: This is Racey - // Grab the read lock - p.RLock() - defer p.RUnlock() + // Execute each query + for _, query := range strutil.ParseArbitraryStringSlice(createStmt, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } - return nil -} + err = session.Query(queryHelper(query, map[string]string{ + "username": username, + "password": password, + })).Exec() + if err != nil { + for _, query := range strutil.ParseArbitraryStringSlice(rollbackStmt, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } -func (p *Cassandra) CustomRevokeUser(username, revocationSQL string) error { - db, err := p.Connection() - if err != nil { - return err + session.Query(queryHelper(query, map[string]string{ + "username": username, + "password": password, + })).Exec() + } + return err + } } - // TODO: this is Racey - p.RLock() - defer p.RUnlock() return nil } -func (p *Cassandra) DefaultRevokeUser(username string) error { - // Grab the read lock - p.RLock() - defer p.RUnlock() - - db, err := p.Connection() - +func (c *Cassandra) RenewUser(username, expiration string) error { + // NOOP return nil } -func createSession(cfg *ConnectionConfig) (*gocql.Session, error) { - clusterConfig := gocql.NewCluster(strings.Split(cfg.Hosts, ",")...) - clusterConfig.Authenticator = gocql.PasswordAuthenticator{ - Username: cfg.Username, - Password: cfg.Password, - } - - clusterConfig.ProtoVersion = cfg.ProtocolVersion - if clusterConfig.ProtoVersion == 0 { - clusterConfig.ProtoVersion = 2 - } - - clusterConfig.Timeout = time.Duration(cfg.ConnectTimeout) * time.Second - - if cfg.TLS { - var tlsConfig *tls.Config - if len(cfg.Certificate) > 0 || len(cfg.IssuingCA) > 0 { - if len(cfg.Certificate) > 0 && len(cfg.PrivateKey) == 0 { - return nil, fmt.Errorf("Found certificate for TLS authentication but no private key") - } - - certBundle := &certutil.CertBundle{} - if len(cfg.Certificate) > 0 { - certBundle.Certificate = cfg.Certificate - certBundle.PrivateKey = cfg.PrivateKey - } - if len(cfg.IssuingCA) > 0 { - certBundle.IssuingCA = cfg.IssuingCA - } - - parsedCertBundle, err := certBundle.ToParsedCertBundle() - if err != nil { - return nil, fmt.Errorf("failed to parse certificate bundle: %s", err) - } - - tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient) - if err != nil || tlsConfig == nil { - return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%v", tlsConfig, err) - } - tlsConfig.InsecureSkipVerify = cfg.InsecureTLS - - if cfg.TLSMinVersion != "" { - var ok bool - tlsConfig.MinVersion, ok = tlsutil.TLSLookup[cfg.TLSMinVersion] - if !ok { - return nil, fmt.Errorf("invalid 'tls_min_version' in config") - } - } else { - // MinVersion was not being set earlier. Reset it to - // zero to gracefully handle upgrades. - tlsConfig.MinVersion = 0 - } - } - - clusterConfig.SslOpts = &gocql.SslOptions{ - Config: *tlsConfig, - } - } - - session, err := clusterConfig.CreateSession() +func (c *Cassandra) RevokeUser(username, revocationSQL string) error { + session, err := c.getConnection() if err != nil { - return nil, fmt.Errorf("Error creating session: %s", err) + return err } + // TODO: this is Racey + c.RLock() + defer c.RUnlock() - // Verify the info - err = session.Query(`LIST USERS`).Exec() + err = session.Query(fmt.Sprintf("DROP USER '%s'", username)).Exec() if err != nil { - return nil, fmt.Errorf("Error validating connection info: %s", err) + return fmt.Errorf("error removing user %s", username) } - return session, nil + return nil } diff --git a/builtin/logical/database/dbs/connectionproducer.go b/builtin/logical/database/dbs/connectionproducer.go new file mode 100644 index 000000000000..adecfd55acdb --- /dev/null +++ b/builtin/logical/database/dbs/connectionproducer.go @@ -0,0 +1,254 @@ +package dbs + +import ( + "crypto/tls" + "database/sql" + "fmt" + "strings" + "sync" + "time" + + "github.com/gocql/gocql" + "github.com/hashicorp/vault/helper/certutil" + "github.com/hashicorp/vault/helper/tlsutil" + "github.com/mitchellh/mapstructure" +) + +type ConnectionProducer interface { + Connection() (interface{}, error) + Close() + // TODO: Should we make this immutable instead? + Reset(*DatabaseConfig) error +} + +// sqlConnectionProducer impliments ConnectionProducer and provides a generic producer for most sql databases +type sqlConnectionDetails struct { + ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"` +} + +type sqlConnectionProducer struct { + config *DatabaseConfig + // TODO: Should we merge these two structures make it immutable? + connDetails *sqlConnectionDetails + + db *sql.DB + sync.Mutex +} + +func (cp *sqlConnectionProducer) Connection() (interface{}, error) { + // Grab the write lock + cp.Lock() + defer cp.Unlock() + + // If we already have a DB, we got it! + if cp.db != nil { + if err := cp.db.Ping(); err == nil { + return cp.db, nil + } + // If the ping was unsuccessful, close it and ignore errors as we'll be + // reestablishing anyways + cp.db.Close() + } + + // Otherwise, attempt to make connection + conn := cp.connDetails.ConnectionURL + + // Ensure timezone is set to UTC for all the conenctions + if strings.HasPrefix(conn, "postgres://") || strings.HasPrefix(conn, "postgresql://") { + if strings.Contains(conn, "?") { + conn += "&timezone=utc" + } else { + conn += "?timezone=utc" + } + } else { + conn += " timezone=utc" + } + + var err error + cp.db, err = sql.Open(cp.config.DatabaseType, conn) + if err != nil { + return nil, err + } + + // Set some connection pool settings. We don't need much of this, + // since the request rate shouldn't be high. + cp.db.SetMaxOpenConns(cp.config.MaxOpenConnections) + cp.db.SetMaxIdleConns(cp.config.MaxIdleConnections) + + return cp.db, nil +} + +func (cp *sqlConnectionProducer) Close() { + // Grab the write lock + cp.Lock() + defer cp.Unlock() + + if cp.db != nil { + cp.db.Close() + } + + cp.db = nil +} + +func (cp *sqlConnectionProducer) Reset(config *DatabaseConfig) error { + // Grab the write lock + cp.Lock() + + var details *sqlConnectionDetails + err := mapstructure.Decode(config.ConnectionDetails, &details) + if err != nil { + return err + } + + cp.connDetails = details + cp.config = config + + cp.Unlock() + + cp.Close() + _, err = cp.Connection() + return err +} + +// cassandraConnectionProducer impliments ConnectionProducer and provides connections for cassandra +type cassandraConnectionDetails struct { + Hosts string `json:"hosts" structs:"hosts" mapstructure:"hosts"` + Username string `json:"username" structs:"username" mapstructure:"username"` + Password string `json:"password" structs:"password" mapstructure:"password"` + TLS bool `json:"tls" structs:"tls" mapstructure:"tls"` + InsecureTLS bool `json:"insecure_tls" structs:"insecure_tls" mapstructure:"insecure_tls"` + Certificate string `json:"certificate" structs:"certificate" mapstructure:"certificate"` + PrivateKey string `json:"private_key" structs:"private_key" mapstructure:"private_key"` + IssuingCA string `json:"issuing_ca" structs:"issuing_ca" mapstructure:"issuing_ca"` + ProtocolVersion int `json:"protocol_version" structs:"protocol_version" mapstructure:"protocol_version"` + ConnectTimeout int `json:"connect_timeout" structs:"connect_timeout" mapstructure:"connect_timeout"` + TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"` + Consistancy string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"` +} + +type cassandraConnectionProducer struct { + config *DatabaseConfig + // TODO: Should we merge these two structures make it immutable? + connDetails *cassandraConnectionDetails + + session *gocql.Session + sync.Mutex +} + +func (cp *cassandraConnectionProducer) Connection() (interface{}, error) { + // Grab the write lock + cp.Lock() + defer cp.Unlock() + + // If we already have a DB, we got it! + if cp.session != nil { + return cp.session, nil + } + + session, err := cp.createSession(cp.connDetails) + if err != nil { + return nil, err + } + + // Store the session in backend for reuse + cp.session = session + + return session, nil +} + +func (cp *cassandraConnectionProducer) Close() { + // Grab the write lock + cp.Lock() + defer cp.Unlock() + + if cp.session != nil { + cp.session.Close() + } + + cp.session = nil +} + +func (cp *cassandraConnectionProducer) Reset(config *DatabaseConfig) error { + // Grab the write lock + cp.Lock() + cp.config = config + cp.Unlock() + + cp.Close() + _, err := cp.Connection() + + return err +} + +func (cp *cassandraConnectionProducer) createSession(cfg *cassandraConnectionDetails) (*gocql.Session, error) { + clusterConfig := gocql.NewCluster(strings.Split(cfg.Hosts, ",")...) + clusterConfig.Authenticator = gocql.PasswordAuthenticator{ + Username: cfg.Username, + Password: cfg.Password, + } + + clusterConfig.ProtoVersion = cfg.ProtocolVersion + if clusterConfig.ProtoVersion == 0 { + clusterConfig.ProtoVersion = 2 + } + + clusterConfig.Timeout = time.Duration(cfg.ConnectTimeout) * time.Second + + if cfg.TLS { + var tlsConfig *tls.Config + if len(cfg.Certificate) > 0 || len(cfg.IssuingCA) > 0 { + if len(cfg.Certificate) > 0 && len(cfg.PrivateKey) == 0 { + return nil, fmt.Errorf("Found certificate for TLS authentication but no private key") + } + + certBundle := &certutil.CertBundle{} + if len(cfg.Certificate) > 0 { + certBundle.Certificate = cfg.Certificate + certBundle.PrivateKey = cfg.PrivateKey + } + if len(cfg.IssuingCA) > 0 { + certBundle.IssuingCA = cfg.IssuingCA + } + + parsedCertBundle, err := certBundle.ToParsedCertBundle() + if err != nil { + return nil, fmt.Errorf("failed to parse certificate bundle: %s", err) + } + + tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient) + if err != nil || tlsConfig == nil { + return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%v", tlsConfig, err) + } + tlsConfig.InsecureSkipVerify = cfg.InsecureTLS + + if cfg.TLSMinVersion != "" { + var ok bool + tlsConfig.MinVersion, ok = tlsutil.TLSLookup[cfg.TLSMinVersion] + if !ok { + return nil, fmt.Errorf("invalid 'tls_min_version' in config") + } + } else { + // MinVersion was not being set earlier. Reset it to + // zero to gracefully handle upgrades. + tlsConfig.MinVersion = 0 + } + } + + clusterConfig.SslOpts = &gocql.SslOptions{ + Config: *tlsConfig, + } + } + + session, err := clusterConfig.CreateSession() + if err != nil { + return nil, fmt.Errorf("Error creating session: %s", err) + } + + // Verify the info + err = session.Query(`LIST USERS`).Exec() + if err != nil { + return nil, fmt.Errorf("Error validating connection info: %s", err) + } + + return session, nil +} diff --git a/builtin/logical/database/dbs/credentialsproducer.go b/builtin/logical/database/dbs/credentialsproducer.go new file mode 100644 index 000000000000..20210c2e85a4 --- /dev/null +++ b/builtin/logical/database/dbs/credentialsproducer.go @@ -0,0 +1,79 @@ +package dbs + +import ( + "fmt" + "strings" + "time" + + uuid "github.com/hashicorp/go-uuid" +) + +type CredentialsProducer interface { + GenerateUsername(displayName string) (string, error) + GeneratePassword() (string, error) + GenerateExpiration(ttl time.Duration) string +} + +// sqlCredentialsProducer impliments CredentialsProducer and provides a generic credentials producer for most sql database types. +type sqlCredentialsProducer struct { + displayNameLen int + usernameLen int +} + +func (scg *sqlCredentialsProducer) GenerateUsername(displayName string) (string, error) { + // Generate the username, password and expiration. PG limits user to 63 characters + if scg.displayNameLen > 0 && len(displayName) > scg.displayNameLen { + displayName = displayName[:scg.displayNameLen] + } + userUUID, err := uuid.GenerateUUID() + if err != nil { + return "", err + } + username := fmt.Sprintf("%s-%s", displayName, userUUID) + if scg.usernameLen > 0 && len(username) > scg.usernameLen { + username = username[:scg.usernameLen] + } + + return username, nil +} + +func (scg *sqlCredentialsProducer) GeneratePassword() (string, error) { + password, err := uuid.GenerateUUID() + if err != nil { + return "", err + } + + return password, nil +} + +func (scg *sqlCredentialsProducer) GenerateExpiration(ttl time.Duration) string { + return time.Now(). + Add(ttl). + Format("2006-01-02 15:04:05-0700") +} + +type cassandraCredentialsProducer struct{} + +func (ccp *cassandraCredentialsProducer) GenerateUsername(displayName string) (string, error) { + userUUID, err := uuid.GenerateUUID() + if err != nil { + return "", err + } + username := fmt.Sprintf("vault_%s_%s_%d", displayName, userUUID, time.Now().Unix()) + username = strings.Replace(username, "-", "_", -1) + + return username, nil +} + +func (ccp *cassandraCredentialsProducer) GeneratePassword() (string, error) { + password, err := uuid.GenerateUUID() + if err != nil { + return "", err + } + + return password, nil +} + +func (ccp *cassandraCredentialsProducer) GenerateExpiration(ttl time.Duration) string { + return "" +} diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go index ee7b15b64d40..9d261ff42bfd 100644 --- a/builtin/logical/database/dbs/db.go +++ b/builtin/logical/database/dbs/db.go @@ -1,10 +1,11 @@ package dbs import ( - "database/sql" "errors" "fmt" "strings" + + "github.com/mitchellh/mapstructure" ) const ( @@ -16,11 +17,47 @@ var ( ErrUnsupportedDatabaseType = errors.New("Unsupported database type") ) -func Factory(conf ConnectionConfig) (DatabaseType, error) { - switch conf.ConnectionType { +func Factory(conf *DatabaseConfig) (DatabaseType, error) { + switch conf.DatabaseType { case postgreSQLTypeName: + var details *sqlConnectionDetails + err := mapstructure.Decode(conf.ConnectionDetails, &details) + if err != nil { + return nil, err + } + + connProducer := &sqlConnectionProducer{ + config: conf, + connDetails: details, + } + + credsProducer := &sqlCredentialsProducer{ + displayNameLen: 23, + usernameLen: 63, + } + return &PostgreSQL{ - config: conf, + ConnectionProducer: connProducer, + CredentialsProducer: credsProducer, + }, nil + + case cassandraTypeName: + var details *cassandraConnectionDetails + err := mapstructure.Decode(conf.ConnectionDetails, &details) + if err != nil { + return nil, err + } + + connProducer := &cassandraConnectionProducer{ + config: conf, + connDetails: details, + } + + credsProducer := &cassandraCredentialsProducer{} + + return &Cassandra{ + ConnectionProducer: connProducer, + CredentialsProducer: credsProducer, }, nil } @@ -29,21 +66,19 @@ func Factory(conf ConnectionConfig) (DatabaseType, error) { type DatabaseType interface { Type() string - Connection() (*sql.DB, error) - Close() - Reset(ConnectionConfig) (*sql.DB, error) - CreateUser(createStmt, username, password, expiration string) error + CreateUser(createStmt, rollbackStmt, username, password, expiration string) error RenewUser(username, expiration string) error - CustomRevokeUser(username, revocationSQL string) error - DefaultRevokeUser(username string) error + RevokeUser(username, revocationStmt string) error + + ConnectionProducer + CredentialsProducer } -type ConnectionConfig struct { - ConnectionType string `json:"type" structs:"type" mapstructure:"type"` - ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"` - ConnectionDetails map[string]string `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"` - MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"` - MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"` +type DatabaseConfig struct { + DatabaseType string `json:"type" structs:"type" mapstructure:"type"` + ConnectionDetails map[string]interface{} `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"` + MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"` + MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"` } // Query templates a query for us. diff --git a/builtin/logical/database/dbs/postgresql.go b/builtin/logical/database/dbs/postgresql.go index ea7d08f8ac78..e050e30bf537 100644 --- a/builtin/logical/database/dbs/postgresql.go +++ b/builtin/logical/database/dbs/postgresql.go @@ -11,9 +11,10 @@ import ( ) type PostgreSQL struct { - db *sql.DB - config ConnectionConfig + db *sql.DB + ConnectionProducer + CredentialsProducer sync.RWMutex } @@ -21,74 +22,18 @@ func (p *PostgreSQL) Type() string { return postgreSQLTypeName } -func (p *PostgreSQL) Connection() (*sql.DB, error) { - // Grab the write lock - p.Lock() - defer p.Unlock() - - // If we already have a DB, we got it! - if p.db != nil { - if err := p.db.Ping(); err == nil { - return p.db, nil - } - // If the ping was unsuccessful, close it and ignore errors as we'll be - // reestablishing anyways - p.db.Close() - } - - // Otherwise, attempt to make connection - conn := p.config.ConnectionURL - - // Ensure timezone is set to UTC for all the conenctions - if strings.HasPrefix(conn, "postgres://") || strings.HasPrefix(conn, "postgresql://") { - if strings.Contains(conn, "?") { - conn += "&timezone=utc" - } else { - conn += "?timezone=utc" - } - } else { - conn += " timezone=utc" - } - - var err error - p.db, err = sql.Open("postgres", conn) +func (p *PostgreSQL) getConnection() (*sql.DB, error) { + db, err := p.Connection() if err != nil { return nil, err } - // Set some connection pool settings. We don't need much of this, - // since the request rate shouldn't be high. - p.db.SetMaxOpenConns(p.config.MaxOpenConnections) - p.db.SetMaxIdleConns(p.config.MaxIdleConnections) - - return p.db, nil -} - -func (p *PostgreSQL) Close() { - // Grab the write lock - p.Lock() - defer p.Unlock() - - if p.db != nil { - p.db.Close() - } - - p.db = nil -} - -func (p *PostgreSQL) Reset(config ConnectionConfig) (*sql.DB, error) { - // Grab the write lock - p.Lock() - p.config = config - p.Unlock() - - p.Close() - return p.Connection() + return db.(*sql.DB), nil } -func (p *PostgreSQL) CreateUser(createStmt, username, password, expiration string) error { +func (p *PostgreSQL) CreateUser(createStmt, rollbackStmt, username, password, expiration string) error { // Get the connection - db, err := p.Connection() + db, err := p.getConnection() if err != nil { return err } @@ -144,7 +89,7 @@ func (p *PostgreSQL) CreateUser(createStmt, username, password, expiration strin } func (p *PostgreSQL) RenewUser(username, expiration string) error { - db, err := p.Connection() + db, err := p.getConnection() if err != nil { return err } @@ -170,14 +115,23 @@ func (p *PostgreSQL) RenewUser(username, expiration string) error { return nil } -func (p *PostgreSQL) CustomRevokeUser(username, revocationSQL string) error { - db, err := p.Connection() +func (p *PostgreSQL) RevokeUser(username, revocationStmt string) error { + // Grab the read lock + p.RLock() + defer p.RUnlock() + + if revocationStmt == "" { + return p.defaultRevokeUser(username) + } + + return p.customRevokeUser(username, revocationStmt) +} + +func (p *PostgreSQL) customRevokeUser(username, revocationStmt string) error { + db, err := p.getConnection() if err != nil { return err } - // TODO: this is Racey - p.RLock() - defer p.RUnlock() tx, err := db.Begin() if err != nil { @@ -187,7 +141,7 @@ func (p *PostgreSQL) CustomRevokeUser(username, revocationSQL string) error { tx.Rollback() }() - for _, query := range strutil.ParseArbitraryStringSlice(revocationSQL, ";") { + for _, query := range strutil.ParseArbitraryStringSlice(revocationStmt, ";") { query = strings.TrimSpace(query) if len(query) == 0 { continue @@ -213,12 +167,8 @@ func (p *PostgreSQL) CustomRevokeUser(username, revocationSQL string) error { return nil } -func (p *PostgreSQL) DefaultRevokeUser(username string) error { - // Grab the read lock - p.RLock() - defer p.RUnlock() - - db, err := p.Connection() +func (p *PostgreSQL) defaultRevokeUser(username string) error { + db, err := p.getConnection() if err != nil { return err } diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index be017ea35c02..d4a969a69743 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -79,7 +79,7 @@ func (b *databaseBackend) pathConnectionRead(req *logical.Request, data *framewo return nil, nil } - var config dbs.ConnectionConfig + var config dbs.DatabaseConfig if err := entry.DecodeJSON(&config); err != nil { return nil, err } @@ -89,8 +89,8 @@ func (b *databaseBackend) pathConnectionRead(req *logical.Request, data *framewo } func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - connURL := data.Get("connection_url").(string) connType := data.Get("connection_type").(string) + connDetails := data.Get("connection_details").(map[string]interface{}) maxOpenConns := data.Get("max_open_connections").(int) if maxOpenConns == 0 { @@ -105,9 +105,9 @@ func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framew maxIdleConns = maxOpenConns } - config := dbs.ConnectionConfig{ - ConnectionType: connType, - ConnectionURL: connURL, + config := &dbs.DatabaseConfig{ + DatabaseType: connType, + ConnectionDetails: connDetails, MaxOpenConnections: maxOpenConns, MaxIdleConnections: maxIdleConns, } diff --git a/builtin/logical/database/path_config_lease.go b/builtin/logical/database/path_config_lease.go deleted file mode 100644 index 5cc40a056e9d..000000000000 --- a/builtin/logical/database/path_config_lease.go +++ /dev/null @@ -1,103 +0,0 @@ -package database - -import ( - "fmt" - "time" - - "github.com/hashicorp/vault/logical" - "github.com/hashicorp/vault/logical/framework" -) - -func pathConfigLease(b *databaseBackend) *framework.Path { - return &framework.Path{ - Pattern: "config/lease", - Fields: map[string]*framework.FieldSchema{ - "lease": &framework.FieldSchema{ - Type: framework.TypeString, - Description: "Default lease for roles.", - }, - - "lease_max": &framework.FieldSchema{ - Type: framework.TypeString, - Description: "Maximum time a credential is valid for.", - }, - }, - - Callbacks: map[logical.Operation]framework.OperationFunc{ - logical.ReadOperation: b.pathLeaseRead, - logical.UpdateOperation: b.pathLeaseWrite, - }, - - HelpSynopsis: pathConfigLeaseHelpSyn, - HelpDescription: pathConfigLeaseHelpDesc, - } -} - -func (b *databaseBackend) pathLeaseWrite( - req *logical.Request, d *framework.FieldData) (*logical.Response, error) { - leaseRaw := d.Get("lease").(string) - leaseMaxRaw := d.Get("lease_max").(string) - - lease, err := time.ParseDuration(leaseRaw) - if err != nil { - return logical.ErrorResponse(fmt.Sprintf( - "Invalid lease: %s", err)), nil - } - leaseMax, err := time.ParseDuration(leaseMaxRaw) - if err != nil { - return logical.ErrorResponse(fmt.Sprintf( - "Invalid lease: %s", err)), nil - } - - // Store it - entry, err := logical.StorageEntryJSON("config/lease", &configLease{ - Lease: lease, - LeaseMax: leaseMax, - }) - if err != nil { - return nil, err - } - if err := req.Storage.Put(entry); err != nil { - return nil, err - } - - return nil, nil -} - -func (b *databaseBackend) pathLeaseRead( - req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - lease, err := b.Lease(req.Storage) - - if err != nil { - return nil, err - } - if lease == nil { - return nil, nil - } - - return &logical.Response{ - Data: map[string]interface{}{ - "lease": lease.Lease.String(), - "lease_max": lease.LeaseMax.String(), - }, - }, nil -} - -type configLease struct { - Lease time.Duration - LeaseMax time.Duration -} - -const pathConfigLeaseHelpSyn = ` -Configure the default lease information for generated credentials. -` - -const pathConfigLeaseHelpDesc = ` -This configures the default lease information used for credentials -generated by this backend. The lease specifies the duration that a -credential will be valid for, as well as the maximum session for -a set of credentials. - -The format for the lease is "1h" or integer and then unit. The longest -unit is hour. -` diff --git a/builtin/logical/database/path_role_create.go b/builtin/logical/database/path_role_create.go index 2a2386d01213..15ca915bae36 100644 --- a/builtin/logical/database/path_role_create.go +++ b/builtin/logical/database/path_role_create.go @@ -2,9 +2,7 @@ package database import ( "fmt" - "time" - "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" _ "github.com/lib/pq" @@ -45,41 +43,7 @@ func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framewo return logical.ErrorResponse(fmt.Sprintf("unknown role: %s", name)), nil } - // Determine if we have a lease - b.logger.Trace("postgres/pathRoleCreateRead: getting lease") - lease, err := b.Lease(req.Storage) - if err != nil { - return nil, err - } - // Unlike some other backends we need a lease here (can't leave as 0 and - // let core fill it in) because Postgres also expires users as a safety - // measure, so cannot be zero - if lease == nil { - lease = &configLease{ - Lease: b.System().DefaultLeaseTTL(), - } - } - // Generate the username, password and expiration. PG limits user to 63 characters - displayName := req.DisplayName - if len(displayName) > 26 { - displayName = displayName[:26] - } - userUUID, err := uuid.GenerateUUID() - if err != nil { - return nil, err - } - username := fmt.Sprintf("%s-%s", displayName, userUUID) - if len(username) > 63 { - username = username[:63] - } - password, err := uuid.GenerateUUID() - if err != nil { - return nil, err - } - expiration := time.Now(). - Add(lease.Lease). - Format("2006-01-02 15:04:05-0700") // Get our handle b.logger.Trace("postgres/pathRoleCreateRead: getting database handle") @@ -92,7 +56,19 @@ func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framewo return nil, fmt.Errorf("Cound not find DB with name: %s", role.DBName) } - err = db.CreateUser(role.CreationStatement, username, password, expiration) + username, err := db.GenerateUsername(req.DisplayName) + if err != nil { + return nil, err + } + + password, err := db.GeneratePassword() + if err != nil { + return nil, err + } + + expiration := db.GenerateExpiration(role.DefaultTTL) + + err = db.CreateUser(role.CreationStatement, role.RollbackStatement, username, password, expiration) if err != nil { return nil, err } @@ -105,7 +81,7 @@ func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framewo "username": username, "role": name, }) - resp.Secret.TTL = lease.Lease + resp.Secret.TTL = role.DefaultTTL return resp, nil } diff --git a/builtin/logical/database/path_roles.go b/builtin/logical/database/path_roles.go index e06518b289a4..dc8c6805ab88 100644 --- a/builtin/logical/database/path_roles.go +++ b/builtin/logical/database/path_roles.go @@ -1,6 +1,9 @@ package database import ( + "fmt" + "time" + "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -44,6 +47,24 @@ func pathRoles(b *databaseBackend) *framework.Path { array, or a base64-encoded serialized JSON string array. The '{{name}}' value will be substituted.`, }, + + "rollback_statement": { + Type: framework.TypeString, + Description: `SQL statements to be executed to revoke a user. Must be a semicolon-separated + string, a base64-encoded semicolon-separated string, a serialized JSON string + array, or a base64-encoded serialized JSON string array. The '{{name}}' value + will be substituted.`, + }, + + "default_ttl": { + Type: framework.TypeString, + Description: "Default ttl for role.", + }, + + "max_ttl": { + Type: framework.TypeString, + Description: "Maximum time a credential is valid for", + }, }, Callbacks: map[logical.Operation]framework.OperationFunc{ @@ -79,6 +100,9 @@ func (b *databaseBackend) pathRoleRead(req *logical.Request, data *framework.Fie Data: map[string]interface{}{ "creation_statment": role.CreationStatement, "revocation_statement": role.RevocationStatement, + "rollback_statement": role.RollbackStatement, + "default_ttl": role.DefaultTTL.String(), + "max_ttl": role.MaxTTL.String(), }, }, nil } @@ -97,6 +121,20 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F dbName := data.Get("db_name").(string) creationStmt := data.Get("creation_statement").(string) revocationStmt := data.Get("revocation_statement").(string) + rollbackStmt := data.Get("rollback_statement").(string) + defaultTTLRaw := data.Get("default_ttl").(string) + maxTTLRaw := data.Get("max_ttl").(string) + + defaultTTL, err := time.ParseDuration(defaultTTLRaw) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf( + "Invalid default_ttl: %s", err)), nil + } + maxTTL, err := time.ParseDuration(maxTTLRaw) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf( + "Invalid max_ttl: %s", err)), nil + } // TODO: Think about preparing the statments to test. @@ -105,6 +143,9 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F DBName: dbName, CreationStatement: creationStmt, RevocationStatement: revocationStmt, + RollbackStatement: rollbackStmt, + DefaultTTL: defaultTTL, + MaxTTL: maxTTL, }) if err != nil { return nil, err @@ -117,9 +158,12 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F } type roleEntry struct { - DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"` - CreationStatement string `json:"creation_statment" mapstructure:"creation_statement" structs:"creation_statment"` - RevocationStatement string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"` + DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"` + CreationStatement string `json:"creation_statment" mapstructure:"creation_statement" structs:"creation_statment"` + RevocationStatement string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"` + RollbackStatement string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"` + DefaultTTL time.Duration `json:"default_ttl" mapstructure:"default_ttl" structs:"default_ttl"` + MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl" structs:"max_ttl"` } const pathRoleHelpSyn = ` diff --git a/builtin/logical/database/secret_creds.go b/builtin/logical/database/secret_creds.go index 30c4a6430f22..120804e91ce3 100644 --- a/builtin/logical/database/secret_creds.go +++ b/builtin/logical/database/secret_creds.go @@ -1,7 +1,6 @@ package database import ( - "errors" "fmt" "github.com/hashicorp/vault/logical" @@ -31,8 +30,6 @@ func secretCreds(b *databaseBackend) *framework.Secret { } func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { - dbName := d.Get("name").(string) - // Get the username from the internal data usernameRaw, ok := req.Secret.InternalData["username"] if !ok { @@ -40,27 +37,35 @@ func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.Fi } username, ok := usernameRaw.(string) - // Get our connection - db, ok := b.connections[dbName] + roleNameRaw, ok := req.Secret.InternalData["role"] if !ok { - return nil, errors.New(fmt.Sprintf("Could not find connection with name %s", dbName)) + return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"]) } - // Get the lease information - lease, err := b.Lease(req.Storage) + role, err := b.Role(req.Storage, roleNameRaw.(string)) if err != nil { return nil, err } - if lease == nil { - lease = &configLease{} + if role == nil { + return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"]) } - f := framework.LeaseExtend(lease.Lease, lease.LeaseMax, b.System()) + f := framework.LeaseExtend(role.DefaultTTL, role.MaxTTL, b.System()) resp, err := f(req, d) if err != nil { return nil, err } + // Grab the read lock + b.RLock() + defer b.RUnlock() + + // Get our connection + db, ok := b.connections[role.DBName] + if !ok { + return nil, fmt.Errorf("Could not find connection with name %s", role.DBName) + } + // Make sure we increase the VALID UNTIL endpoint for this user. if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() { expiration := expireTime.Format("2006-01-02 15:04:05-0700") @@ -124,23 +129,9 @@ func (b *databaseBackend) secretCredsRevoke(req *logical.Request, d *framework.F return nil, fmt.Errorf("Could not find database with name: %s", role.DBName) } - // TODO: Maybe move this down into db package? - switch revocationSQL { - - // This is the default revocation logic. If revocation SQL is provided it - // is simply executed as-is. - case "": - err := db.DefaultRevokeUser(username) - if err != nil { - return nil, err - } - - // We have revocation SQL, execute directly, within a transaction - default: - err := db.CustomRevokeUser(username, revocationSQL) - if err != nil { - return nil, err - } + err = db.RevokeUser(username, revocationSQL) + if err != nil { + return nil, err } return resp, nil From bfbb104e19498afe8aef2c1acf06f48f6cc99775 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 4 Jan 2017 10:18:10 -0800 Subject: [PATCH 003/152] Add mysql database type --- .../database/dbs/connectionproducer.go | 1 + .../database/dbs/credentialsproducer.go | 14 +- builtin/logical/database/dbs/mysql.go | 136 ++++++++++++++++++ .../database/path_config_connection.go | 2 +- 4 files changed, 145 insertions(+), 8 deletions(-) create mode 100644 builtin/logical/database/dbs/mysql.go diff --git a/builtin/logical/database/dbs/connectionproducer.go b/builtin/logical/database/dbs/connectionproducer.go index adecfd55acdb..dc8f6c82c1bf 100644 --- a/builtin/logical/database/dbs/connectionproducer.go +++ b/builtin/logical/database/dbs/connectionproducer.go @@ -8,6 +8,7 @@ import ( "sync" "time" + _ "github.com/go-sql-driver/mysql" "github.com/gocql/gocql" "github.com/hashicorp/vault/helper/certutil" "github.com/hashicorp/vault/helper/tlsutil" diff --git a/builtin/logical/database/dbs/credentialsproducer.go b/builtin/logical/database/dbs/credentialsproducer.go index 20210c2e85a4..94fce6275a2e 100644 --- a/builtin/logical/database/dbs/credentialsproducer.go +++ b/builtin/logical/database/dbs/credentialsproducer.go @@ -20,24 +20,24 @@ type sqlCredentialsProducer struct { usernameLen int } -func (scg *sqlCredentialsProducer) GenerateUsername(displayName string) (string, error) { +func (scp *sqlCredentialsProducer) GenerateUsername(displayName string) (string, error) { // Generate the username, password and expiration. PG limits user to 63 characters - if scg.displayNameLen > 0 && len(displayName) > scg.displayNameLen { - displayName = displayName[:scg.displayNameLen] + if scp.displayNameLen > 0 && len(displayName) > scp.displayNameLen { + displayName = displayName[:scp.displayNameLen] } userUUID, err := uuid.GenerateUUID() if err != nil { return "", err } username := fmt.Sprintf("%s-%s", displayName, userUUID) - if scg.usernameLen > 0 && len(username) > scg.usernameLen { - username = username[:scg.usernameLen] + if scp.usernameLen > 0 && len(username) > scp.usernameLen { + username = username[:scp.usernameLen] } return username, nil } -func (scg *sqlCredentialsProducer) GeneratePassword() (string, error) { +func (scp *sqlCredentialsProducer) GeneratePassword() (string, error) { password, err := uuid.GenerateUUID() if err != nil { return "", err @@ -46,7 +46,7 @@ func (scg *sqlCredentialsProducer) GeneratePassword() (string, error) { return password, nil } -func (scg *sqlCredentialsProducer) GenerateExpiration(ttl time.Duration) string { +func (scp *sqlCredentialsProducer) GenerateExpiration(ttl time.Duration) string { return time.Now(). Add(ttl). Format("2006-01-02 15:04:05-0700") diff --git a/builtin/logical/database/dbs/mysql.go b/builtin/logical/database/dbs/mysql.go new file mode 100644 index 000000000000..314d4c929329 --- /dev/null +++ b/builtin/logical/database/dbs/mysql.go @@ -0,0 +1,136 @@ +package dbs + +import ( + "database/sql" + "strings" + "sync" + + "github.com/hashicorp/vault/helper/strutil" +) + +const defaultRevocationSQL = ` + REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'%'; + DROP USER '{{name}}'@'%' +` + +type MySQL struct { + db *sql.DB + + ConnectionProducer + CredentialsProducer + sync.RWMutex +} + +func (p *MySQL) Type() string { + return postgreSQLTypeName +} + +func (p *MySQL) getConnection() (*sql.DB, error) { + db, err := p.Connection() + if err != nil { + return nil, err + } + + return db.(*sql.DB), nil +} + +func (p *MySQL) CreateUser(createStmt, rollbackStmt, username, password, expiration string) error { + // Get the connection + db, err := p.getConnection() + if err != nil { + return err + } + + // TODO: This is racey + // Grab a read lock + p.RLock() + defer p.RUnlock() + + // Start a transaction + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + // Execute each query + for _, query := range strutil.ParseArbitraryStringSlice(createStmt, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + stmt, err := tx.Prepare(queryHelper(query, map[string]string{ + "name": username, + "password": password, + })) + if err != nil { + return err + } + defer stmt.Close() + if _, err := stmt.Exec(); err != nil { + return err + } + } + + // Commit the transaction + if err := tx.Commit(); err != nil { + return err + } + + return nil +} + +// NOOP +func (p *MySQL) RenewUser(username, expiration string) error { + return nil +} + +func (p *MySQL) RevokeUser(username, revocationStmt string) error { + // Get the connection + db, err := p.getConnection() + if err != nil { + return err + } + + // Grab the read lock + p.RLock() + defer p.RUnlock() + + // Use a default SQL statement for revocation if one cannot be fetched from the role + + if revocationStmt == "" { + revocationStmt = defaultRevocationSQL + } + + // Start a transaction + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + for _, query := range strutil.ParseArbitraryStringSlice(revocationStmt, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + // This is not a prepared statement because not all commands are supported + // 1295: This command is not supported in the prepared statement protocol yet + // Reference https://mariadb.com/kb/en/mariadb/prepare-statement/ + query = strings.Replace(query, "{{name}}", username, -1) + _, err = tx.Exec(query) + if err != nil { + return err + } + + } + + // Commit the transaction + if err := tx.Commit(); err != nil { + return err + } + + return nil +} diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index d4a969a69743..90dfea4cddfb 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -124,7 +124,7 @@ func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framew // Don't allow the connection type to change if b.connections[name].Type() != connType { - return logical.ErrorResponse("can not change type of existing connection"), nil + return logical.ErrorResponse("Can not change type of existing connection."), nil } db = b.connections[name] From cee3dc9b9e880b119a9135b1f3a10e55f0164666 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 4 Jan 2017 10:53:39 -0800 Subject: [PATCH 004/152] s/Statement/Statements/ --- builtin/logical/database/dbs/cassandra.go | 13 ++++-- .../database/dbs/connectionproducer.go | 3 ++ builtin/logical/database/dbs/mysql.go | 14 +++--- builtin/logical/database/dbs/postgresql.go | 14 +++--- builtin/logical/database/path_role_create.go | 3 +- builtin/logical/database/path_roles.go | 46 +++++++++---------- 6 files changed, 50 insertions(+), 43 deletions(-) diff --git a/builtin/logical/database/dbs/cassandra.go b/builtin/logical/database/dbs/cassandra.go index a8889032f93f..7a06e131483d 100644 --- a/builtin/logical/database/dbs/cassandra.go +++ b/builtin/logical/database/dbs/cassandra.go @@ -9,6 +9,11 @@ import ( "github.com/hashicorp/vault/helper/strutil" ) +const ( + defaultCreationCQL = `CREATE USER '{{username}}' WITH PASSWORD '{{password}}' NOSUPERUSER;` + defaultRollbackCQL = `DROP USER '{{username}}';` +) + type Cassandra struct { // Session is goroutine safe, however, since we reinitialize // it when connection info changes, we want to make sure we @@ -31,7 +36,7 @@ func (c *Cassandra) getConnection() (*gocql.Session, error) { return session.(*gocql.Session), nil } -func (c *Cassandra) CreateUser(createStmt, rollbackStmt, username, password, expiration string) error { +func (c *Cassandra) CreateUser(createStmts, rollbackStmts, username, password, expiration string) error { // Get the connection session, err := c.getConnection() if err != nil { @@ -54,7 +59,7 @@ func (c *Cassandra) CreateUser(createStmt, rollbackStmt, username, password, exp }*/ // Execute each query - for _, query := range strutil.ParseArbitraryStringSlice(createStmt, ";") { + for _, query := range strutil.ParseArbitraryStringSlice(createStmts, ";") { query = strings.TrimSpace(query) if len(query) == 0 { continue @@ -65,7 +70,7 @@ func (c *Cassandra) CreateUser(createStmt, rollbackStmt, username, password, exp "password": password, })).Exec() if err != nil { - for _, query := range strutil.ParseArbitraryStringSlice(rollbackStmt, ";") { + for _, query := range strutil.ParseArbitraryStringSlice(rollbackStmts, ";") { query = strings.TrimSpace(query) if len(query) == 0 { continue @@ -88,7 +93,7 @@ func (c *Cassandra) RenewUser(username, expiration string) error { return nil } -func (c *Cassandra) RevokeUser(username, revocationSQL string) error { +func (c *Cassandra) RevokeUser(username, revocationStmts string) error { session, err := c.getConnection() if err != nil { return err diff --git a/builtin/logical/database/dbs/connectionproducer.go b/builtin/logical/database/dbs/connectionproducer.go index dc8f6c82c1bf..5c606996d36e 100644 --- a/builtin/logical/database/dbs/connectionproducer.go +++ b/builtin/logical/database/dbs/connectionproducer.go @@ -8,7 +8,10 @@ import ( "sync" "time" + // Import sql drivers _ "github.com/go-sql-driver/mysql" + _ "github.com/lib/pq" + "github.com/gocql/gocql" "github.com/hashicorp/vault/helper/certutil" "github.com/hashicorp/vault/helper/tlsutil" diff --git a/builtin/logical/database/dbs/mysql.go b/builtin/logical/database/dbs/mysql.go index 314d4c929329..0a18683eaae2 100644 --- a/builtin/logical/database/dbs/mysql.go +++ b/builtin/logical/database/dbs/mysql.go @@ -8,7 +8,7 @@ import ( "github.com/hashicorp/vault/helper/strutil" ) -const defaultRevocationSQL = ` +const defaultRevocationStmts = ` REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'%'; DROP USER '{{name}}'@'%' ` @@ -34,7 +34,7 @@ func (p *MySQL) getConnection() (*sql.DB, error) { return db.(*sql.DB), nil } -func (p *MySQL) CreateUser(createStmt, rollbackStmt, username, password, expiration string) error { +func (p *MySQL) CreateUser(createStmts, rollbackStmts, username, password, expiration string) error { // Get the connection db, err := p.getConnection() if err != nil { @@ -54,7 +54,7 @@ func (p *MySQL) CreateUser(createStmt, rollbackStmt, username, password, expirat defer tx.Rollback() // Execute each query - for _, query := range strutil.ParseArbitraryStringSlice(createStmt, ";") { + for _, query := range strutil.ParseArbitraryStringSlice(createStmts, ";") { query = strings.TrimSpace(query) if len(query) == 0 { continue @@ -86,7 +86,7 @@ func (p *MySQL) RenewUser(username, expiration string) error { return nil } -func (p *MySQL) RevokeUser(username, revocationStmt string) error { +func (p *MySQL) RevokeUser(username, revocationStmts string) error { // Get the connection db, err := p.getConnection() if err != nil { @@ -99,8 +99,8 @@ func (p *MySQL) RevokeUser(username, revocationStmt string) error { // Use a default SQL statement for revocation if one cannot be fetched from the role - if revocationStmt == "" { - revocationStmt = defaultRevocationSQL + if revocationStmts == "" { + revocationStmts = defaultRevocationStmts } // Start a transaction @@ -110,7 +110,7 @@ func (p *MySQL) RevokeUser(username, revocationStmt string) error { } defer tx.Rollback() - for _, query := range strutil.ParseArbitraryStringSlice(revocationStmt, ";") { + for _, query := range strutil.ParseArbitraryStringSlice(revocationStmts, ";") { query = strings.TrimSpace(query) if len(query) == 0 { continue diff --git a/builtin/logical/database/dbs/postgresql.go b/builtin/logical/database/dbs/postgresql.go index e050e30bf537..01fb3cd708e7 100644 --- a/builtin/logical/database/dbs/postgresql.go +++ b/builtin/logical/database/dbs/postgresql.go @@ -31,7 +31,7 @@ func (p *PostgreSQL) getConnection() (*sql.DB, error) { return db.(*sql.DB), nil } -func (p *PostgreSQL) CreateUser(createStmt, rollbackStmt, username, password, expiration string) error { +func (p *PostgreSQL) CreateUser(createStmts, rollbackStmts, username, password, expiration string) error { // Get the connection db, err := p.getConnection() if err != nil { @@ -56,7 +56,7 @@ func (p *PostgreSQL) CreateUser(createStmt, rollbackStmt, username, password, ex // Return the secret // Execute each query - for _, query := range strutil.ParseArbitraryStringSlice(createStmt, ";") { + for _, query := range strutil.ParseArbitraryStringSlice(createStmts, ";") { query = strings.TrimSpace(query) if len(query) == 0 { continue @@ -115,19 +115,19 @@ func (p *PostgreSQL) RenewUser(username, expiration string) error { return nil } -func (p *PostgreSQL) RevokeUser(username, revocationStmt string) error { +func (p *PostgreSQL) RevokeUser(username, revocationStmts string) error { // Grab the read lock p.RLock() defer p.RUnlock() - if revocationStmt == "" { + if revocationStmts == "" { return p.defaultRevokeUser(username) } - return p.customRevokeUser(username, revocationStmt) + return p.customRevokeUser(username, revocationStmts) } -func (p *PostgreSQL) customRevokeUser(username, revocationStmt string) error { +func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error { db, err := p.getConnection() if err != nil { return err @@ -141,7 +141,7 @@ func (p *PostgreSQL) customRevokeUser(username, revocationStmt string) error { tx.Rollback() }() - for _, query := range strutil.ParseArbitraryStringSlice(revocationStmt, ";") { + for _, query := range strutil.ParseArbitraryStringSlice(revocationStmts, ";") { query = strings.TrimSpace(query) if len(query) == 0 { continue diff --git a/builtin/logical/database/path_role_create.go b/builtin/logical/database/path_role_create.go index 15ca915bae36..b1cce97f30e9 100644 --- a/builtin/logical/database/path_role_create.go +++ b/builtin/logical/database/path_role_create.go @@ -5,7 +5,6 @@ import ( "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" - _ "github.com/lib/pq" ) func pathRoleCreate(b *databaseBackend) *framework.Path { @@ -68,7 +67,7 @@ func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framewo expiration := db.GenerateExpiration(role.DefaultTTL) - err = db.CreateUser(role.CreationStatement, role.RollbackStatement, username, password, expiration) + err = db.CreateUser(role.CreationStatements, role.RollbackStatements, username, password, expiration) if err != nil { return nil, err } diff --git a/builtin/logical/database/path_roles.go b/builtin/logical/database/path_roles.go index dc8c6805ab88..994d084f0901 100644 --- a/builtin/logical/database/path_roles.go +++ b/builtin/logical/database/path_roles.go @@ -35,12 +35,12 @@ func pathRoles(b *databaseBackend) *framework.Path { Description: "Name of the database this role acts on.", }, - "creation_statement": { + "creation_statements": { Type: framework.TypeString, Description: "SQL string to create a user. See help for more info.", }, - "revocation_statement": { + "revocation_statements": { Type: framework.TypeString, Description: `SQL statements to be executed to revoke a user. Must be a semicolon-separated string, a base64-encoded semicolon-separated string, a serialized JSON string @@ -48,7 +48,7 @@ func pathRoles(b *databaseBackend) *framework.Path { will be substituted.`, }, - "rollback_statement": { + "rollback_statements": { Type: framework.TypeString, Description: `SQL statements to be executed to revoke a user. Must be a semicolon-separated string, a base64-encoded semicolon-separated string, a serialized JSON string @@ -98,11 +98,11 @@ func (b *databaseBackend) pathRoleRead(req *logical.Request, data *framework.Fie return &logical.Response{ Data: map[string]interface{}{ - "creation_statment": role.CreationStatement, - "revocation_statement": role.RevocationStatement, - "rollback_statement": role.RollbackStatement, - "default_ttl": role.DefaultTTL.String(), - "max_ttl": role.MaxTTL.String(), + "creation_statments": role.CreationStatements, + "revocation_statements": role.RevocationStatements, + "rollback_statements": role.RollbackStatements, + "default_ttl": role.DefaultTTL.String(), + "max_ttl": role.MaxTTL.String(), }, }, nil } @@ -119,9 +119,9 @@ func (b *databaseBackend) pathRoleList(req *logical.Request, d *framework.FieldD func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { name := data.Get("name").(string) dbName := data.Get("db_name").(string) - creationStmt := data.Get("creation_statement").(string) - revocationStmt := data.Get("revocation_statement").(string) - rollbackStmt := data.Get("rollback_statement").(string) + creationStmts := data.Get("creation_statements").(string) + revocationStmts := data.Get("revocation_statements").(string) + rollbackStmts := data.Get("rollback_statements").(string) defaultTTLRaw := data.Get("default_ttl").(string) maxTTLRaw := data.Get("max_ttl").(string) @@ -140,12 +140,12 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F // Store it entry, err := logical.StorageEntryJSON("role/"+name, &roleEntry{ - DBName: dbName, - CreationStatement: creationStmt, - RevocationStatement: revocationStmt, - RollbackStatement: rollbackStmt, - DefaultTTL: defaultTTL, - MaxTTL: maxTTL, + DBName: dbName, + CreationStatements: creationStmts, + RevocationStatements: revocationStmts, + RollbackStatements: rollbackStmts, + DefaultTTL: defaultTTL, + MaxTTL: maxTTL, }) if err != nil { return nil, err @@ -158,12 +158,12 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F } type roleEntry struct { - DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"` - CreationStatement string `json:"creation_statment" mapstructure:"creation_statement" structs:"creation_statment"` - RevocationStatement string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"` - RollbackStatement string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"` - DefaultTTL time.Duration `json:"default_ttl" mapstructure:"default_ttl" structs:"default_ttl"` - MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl" structs:"max_ttl"` + DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"` + CreationStatements string `json:"creation_statment" mapstructure:"creation_statement" structs:"creation_statment"` + RevocationStatements string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"` + RollbackStatements string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"` + DefaultTTL time.Duration `json:"default_ttl" mapstructure:"default_ttl" structs:"default_ttl"` + MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl" structs:"max_ttl"` } const pathRoleHelpSyn = ` From 5e2cffcdd05f0d5cab60145c496327b9faadac37 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 4 Jan 2017 11:28:30 -0800 Subject: [PATCH 005/152] Add max connection lifetime param and set consistancy on cassandra session --- .../database/dbs/connectionproducer.go | 13 +++++++- builtin/logical/database/dbs/db.go | 10 ++++--- .../database/path_config_connection.go | 30 ++++++++++++------- 3 files changed, 37 insertions(+), 16 deletions(-) diff --git a/builtin/logical/database/dbs/connectionproducer.go b/builtin/logical/database/dbs/connectionproducer.go index 5c606996d36e..e1a7ae9bbdb7 100644 --- a/builtin/logical/database/dbs/connectionproducer.go +++ b/builtin/logical/database/dbs/connectionproducer.go @@ -78,6 +78,7 @@ func (cp *sqlConnectionProducer) Connection() (interface{}, error) { // since the request rate shouldn't be high. cp.db.SetMaxOpenConns(cp.config.MaxOpenConnections) cp.db.SetMaxIdleConns(cp.config.MaxIdleConnections) + cp.db.SetConnMaxLifetime(cp.config.MaxConnectionLifetime) return cp.db, nil } @@ -127,7 +128,7 @@ type cassandraConnectionDetails struct { ProtocolVersion int `json:"protocol_version" structs:"protocol_version" mapstructure:"protocol_version"` ConnectTimeout int `json:"connect_timeout" structs:"connect_timeout" mapstructure:"connect_timeout"` TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"` - Consistancy string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"` + Consistency string `json:"consistency" structs:"consistency" mapstructure:"consistency"` } type cassandraConnectionProducer struct { @@ -248,6 +249,16 @@ func (cp *cassandraConnectionProducer) createSession(cfg *cassandraConnectionDet return nil, fmt.Errorf("Error creating session: %s", err) } + // Set consistency + if cfg.Consistency != "" { + consistencyValue, err := gocql.ParseConsistencyWrapper(cfg.Consistency) + if err != nil { + return nil, err + } + + session.SetConsistency(consistencyValue) + } + // Verify the info err = session.Query(`LIST USERS`).Exec() if err != nil { diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go index 9d261ff42bfd..e901f69f8c13 100644 --- a/builtin/logical/database/dbs/db.go +++ b/builtin/logical/database/dbs/db.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "strings" + "time" "github.com/mitchellh/mapstructure" ) @@ -75,10 +76,11 @@ type DatabaseType interface { } type DatabaseConfig struct { - DatabaseType string `json:"type" structs:"type" mapstructure:"type"` - ConnectionDetails map[string]interface{} `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"` - MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"` - MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"` + DatabaseType string `json:"type" structs:"type" mapstructure:"type"` + ConnectionDetails map[string]interface{} `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"` + MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"` + MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"` + MaxConnectionLifetime time.Duration `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"` } // Query templates a query for us. diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 90dfea4cddfb..06cf1dd4ca1d 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -2,12 +2,12 @@ package database import ( "fmt" + "time" "github.com/fatih/structs" "github.com/hashicorp/vault/builtin/logical/database/dbs" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" - _ "github.com/lib/pq" ) func pathConfigConnection(b *databaseBackend) *framework.Path { @@ -24,11 +24,6 @@ func pathConfigConnection(b *databaseBackend) *framework.Path { Description: "DB type (e.g. postgres)", }, - "connection_url": &framework.FieldSchema{ - Type: framework.TypeString, - Description: "DB connection string", - }, - "connection_details": &framework.FieldSchema{ Type: framework.TypeMap, Description: "Connection details for specified connection type.", @@ -55,6 +50,12 @@ and a negative value disables idle connections. If larger than max_open_connections it will be reduced to the same size.`, }, + + "max_connection_lifetime": &framework.FieldSchema{ + Type: framework.TypeInt, + Description: `Maximum amount of time a connection may be reused; + a zero or negative value reuses connections forever.`, + }, }, Callbacks: map[logical.Operation]framework.OperationFunc{ @@ -105,11 +106,19 @@ func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framew maxIdleConns = maxOpenConns } + maxConnLifetimeRaw := data.Get("max_connection_lifetime").(string) + maxConnLifetime, err := time.ParseDuration(maxConnLifetimeRaw) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf( + "Invalid max_connection_lifetime: %s", err)), nil + } + config := &dbs.DatabaseConfig{ - DatabaseType: connType, - ConnectionDetails: connDetails, - MaxOpenConnections: maxOpenConns, - MaxIdleConnections: maxIdleConns, + DatabaseType: connType, + ConnectionDetails: connDetails, + MaxOpenConnections: maxOpenConns, + MaxIdleConnections: maxIdleConns, + MaxConnectionLifetime: maxConnLifetime, } name := data.Get("name").(string) @@ -118,7 +127,6 @@ func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framew b.Lock() defer b.Unlock() - var err error var db dbs.DatabaseType if _, ok := b.connections[name]; ok { From e442917e268f470a16b42065b1e52061a26d3d4c Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 7 Feb 2017 17:32:08 -0800 Subject: [PATCH 006/152] Add mysql into the factory --- builtin/logical/database/dbs/db.go | 23 +++++++++++++++++++ builtin/logical/database/dbs/mysql.go | 2 +- .../database/path_config_connection.go | 6 ++--- 3 files changed, 27 insertions(+), 4 deletions(-) diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go index e901f69f8c13..d648b776fa3f 100644 --- a/builtin/logical/database/dbs/db.go +++ b/builtin/logical/database/dbs/db.go @@ -11,6 +11,7 @@ import ( const ( postgreSQLTypeName = "postgres" + mySQLTypeName = "mysql" cassandraTypeName = "cassandra" ) @@ -42,6 +43,28 @@ func Factory(conf *DatabaseConfig) (DatabaseType, error) { CredentialsProducer: credsProducer, }, nil + case mySQLTypeName: + var details *sqlConnectionDetails + err := mapstructure.Decode(conf.ConnectionDetails, &details) + if err != nil { + return nil, err + } + + connProducer := &sqlConnectionProducer{ + config: conf, + connDetails: details, + } + + credsProducer := &sqlCredentialsProducer{ + displayNameLen: 4, + usernameLen: 16, + } + + return &MySQL{ + ConnectionProducer: connProducer, + CredentialsProducer: credsProducer, + }, nil + case cassandraTypeName: var details *cassandraConnectionDetails err := mapstructure.Decode(conf.ConnectionDetails, &details) diff --git a/builtin/logical/database/dbs/mysql.go b/builtin/logical/database/dbs/mysql.go index 0a18683eaae2..ce6cdac92b21 100644 --- a/builtin/logical/database/dbs/mysql.go +++ b/builtin/logical/database/dbs/mysql.go @@ -22,7 +22,7 @@ type MySQL struct { } func (p *MySQL) Type() string { - return postgreSQLTypeName + return mySQLTypeName } func (p *MySQL) getConnection() (*sql.DB, error) { diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 06cf1dd4ca1d..c3f72b743667 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -52,7 +52,8 @@ reduced to the same size.`, }, "max_connection_lifetime": &framework.FieldSchema{ - Type: framework.TypeInt, + Type: framework.TypeString, + Default: "0s", Description: `Maximum amount of time a connection may be reused; a zero or negative value reuses connections forever.`, }, @@ -91,7 +92,6 @@ func (b *databaseBackend) pathConnectionRead(req *logical.Request, data *framewo func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { connType := data.Get("connection_type").(string) - connDetails := data.Get("connection_details").(map[string]interface{}) maxOpenConns := data.Get("max_open_connections").(int) if maxOpenConns == 0 { @@ -115,7 +115,7 @@ func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framew config := &dbs.DatabaseConfig{ DatabaseType: connType, - ConnectionDetails: connDetails, + ConnectionDetails: data.Raw, MaxOpenConnections: maxOpenConns, MaxIdleConnections: maxIdleConns, MaxConnectionLifetime: maxConnLifetime, From fa8da4cf91ca26ae29d6bcd5fd32006e2ee01f74 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 15 Feb 2017 14:31:15 -0800 Subject: [PATCH 007/152] Fix mysql connections --- builtin/logical/database/dbs/connectionproducer.go | 2 -- builtin/logical/database/path_config_connection.go | 5 ----- 2 files changed, 7 deletions(-) diff --git a/builtin/logical/database/dbs/connectionproducer.go b/builtin/logical/database/dbs/connectionproducer.go index e1a7ae9bbdb7..b53bb0c75732 100644 --- a/builtin/logical/database/dbs/connectionproducer.go +++ b/builtin/logical/database/dbs/connectionproducer.go @@ -64,8 +64,6 @@ func (cp *sqlConnectionProducer) Connection() (interface{}, error) { } else { conn += "?timezone=utc" } - } else { - conn += " timezone=utc" } var err error diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index c3f72b743667..9fe9260508d0 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -24,11 +24,6 @@ func pathConfigConnection(b *databaseBackend) *framework.Path { Description: "DB type (e.g. postgres)", }, - "connection_details": &framework.FieldSchema{ - Type: framework.TypeMap, - Description: "Connection details for specified connection type.", - }, - "verify_connection": &framework.FieldSchema{ Type: framework.TypeBool, Default: true, From 4d335099de385d44003a68c53e4e915170d186a2 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 15 Feb 2017 16:51:59 -0800 Subject: [PATCH 008/152] Make db instances immutable and add a reset path to tear down and create a new database instance with an updated config --- builtin/logical/database/backend.go | 1 + .../database/dbs/connectionproducer.go | 150 +++++++----------- builtin/logical/database/dbs/db.go | 30 ++-- .../database/path_config_connection.go | 66 +++++++- 4 files changed, 124 insertions(+), 123 deletions(-) diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index 3d757df1dcc7..fe853d3fb4c5 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -25,6 +25,7 @@ func Backend(conf *logical.BackendConfig) *databaseBackend { pathListRoles(&b), pathRoles(&b), pathRoleCreate(&b), + pathResetConnection(&b), }, Secrets: []*framework.Secret{ diff --git a/builtin/logical/database/dbs/connectionproducer.go b/builtin/logical/database/dbs/connectionproducer.go index b53bb0c75732..1e66d27f6edc 100644 --- a/builtin/logical/database/dbs/connectionproducer.go +++ b/builtin/logical/database/dbs/connectionproducer.go @@ -15,47 +15,40 @@ import ( "github.com/gocql/gocql" "github.com/hashicorp/vault/helper/certutil" "github.com/hashicorp/vault/helper/tlsutil" - "github.com/mitchellh/mapstructure" ) type ConnectionProducer interface { Connection() (interface{}, error) Close() - // TODO: Should we make this immutable instead? - Reset(*DatabaseConfig) error } // sqlConnectionProducer impliments ConnectionProducer and provides a generic producer for most sql databases -type sqlConnectionDetails struct { +type sqlConnectionProducer struct { ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"` -} -type sqlConnectionProducer struct { config *DatabaseConfig - // TODO: Should we merge these two structures make it immutable? - connDetails *sqlConnectionDetails db *sql.DB sync.Mutex } -func (cp *sqlConnectionProducer) Connection() (interface{}, error) { +func (c *sqlConnectionProducer) Connection() (interface{}, error) { // Grab the write lock - cp.Lock() - defer cp.Unlock() + c.Lock() + defer c.Unlock() // If we already have a DB, we got it! - if cp.db != nil { - if err := cp.db.Ping(); err == nil { - return cp.db, nil + if c.db != nil { + if err := c.db.Ping(); err == nil { + return c.db, nil } // If the ping was unsuccessful, close it and ignore errors as we'll be // reestablishing anyways - cp.db.Close() + c.db.Close() } // Otherwise, attempt to make connection - conn := cp.connDetails.ConnectionURL + conn := c.ConnectionURL // Ensure timezone is set to UTC for all the conenctions if strings.HasPrefix(conn, "postgres://") || strings.HasPrefix(conn, "postgresql://") { @@ -67,54 +60,33 @@ func (cp *sqlConnectionProducer) Connection() (interface{}, error) { } var err error - cp.db, err = sql.Open(cp.config.DatabaseType, conn) + c.db, err = sql.Open(c.config.DatabaseType, conn) if err != nil { return nil, err } // Set some connection pool settings. We don't need much of this, // since the request rate shouldn't be high. - cp.db.SetMaxOpenConns(cp.config.MaxOpenConnections) - cp.db.SetMaxIdleConns(cp.config.MaxIdleConnections) - cp.db.SetConnMaxLifetime(cp.config.MaxConnectionLifetime) - - return cp.db, nil -} + c.db.SetMaxOpenConns(c.config.MaxOpenConnections) + c.db.SetMaxIdleConns(c.config.MaxIdleConnections) + c.db.SetConnMaxLifetime(c.config.MaxConnectionLifetime) -func (cp *sqlConnectionProducer) Close() { - // Grab the write lock - cp.Lock() - defer cp.Unlock() - - if cp.db != nil { - cp.db.Close() - } - - cp.db = nil + return c.db, nil } -func (cp *sqlConnectionProducer) Reset(config *DatabaseConfig) error { +func (c *sqlConnectionProducer) Close() { // Grab the write lock - cp.Lock() + c.Lock() + defer c.Unlock() - var details *sqlConnectionDetails - err := mapstructure.Decode(config.ConnectionDetails, &details) - if err != nil { - return err + if c.db != nil { + c.db.Close() } - cp.connDetails = details - cp.config = config - - cp.Unlock() - - cp.Close() - _, err = cp.Connection() - return err + c.db = nil } -// cassandraConnectionProducer impliments ConnectionProducer and provides connections for cassandra -type cassandraConnectionDetails struct { +type cassandraConnectionProducer struct { Hosts string `json:"hosts" structs:"hosts" mapstructure:"hosts"` Username string `json:"username" structs:"username" mapstructure:"username"` Password string `json:"password" structs:"password" mapstructure:"password"` @@ -127,90 +99,74 @@ type cassandraConnectionDetails struct { ConnectTimeout int `json:"connect_timeout" structs:"connect_timeout" mapstructure:"connect_timeout"` TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"` Consistency string `json:"consistency" structs:"consistency" mapstructure:"consistency"` -} -type cassandraConnectionProducer struct { config *DatabaseConfig - // TODO: Should we merge these two structures make it immutable? - connDetails *cassandraConnectionDetails session *gocql.Session sync.Mutex } -func (cp *cassandraConnectionProducer) Connection() (interface{}, error) { +func (c *cassandraConnectionProducer) Connection() (interface{}, error) { // Grab the write lock - cp.Lock() - defer cp.Unlock() + c.Lock() + defer c.Unlock() // If we already have a DB, we got it! - if cp.session != nil { - return cp.session, nil + if c.session != nil { + return c.session, nil } - session, err := cp.createSession(cp.connDetails) + session, err := c.createSession() if err != nil { return nil, err } // Store the session in backend for reuse - cp.session = session + c.session = session return session, nil } -func (cp *cassandraConnectionProducer) Close() { +func (c *cassandraConnectionProducer) Close() { // Grab the write lock - cp.Lock() - defer cp.Unlock() + c.Lock() + defer c.Unlock() - if cp.session != nil { - cp.session.Close() + if c.session != nil { + c.session.Close() } - cp.session = nil -} - -func (cp *cassandraConnectionProducer) Reset(config *DatabaseConfig) error { - // Grab the write lock - cp.Lock() - cp.config = config - cp.Unlock() - - cp.Close() - _, err := cp.Connection() - - return err + c.session = nil } -func (cp *cassandraConnectionProducer) createSession(cfg *cassandraConnectionDetails) (*gocql.Session, error) { - clusterConfig := gocql.NewCluster(strings.Split(cfg.Hosts, ",")...) +func (c *cassandraConnectionProducer) createSession() (*gocql.Session, error) { + clusterConfig := gocql.NewCluster(strings.Split(c.Hosts, ",")...) clusterConfig.Authenticator = gocql.PasswordAuthenticator{ - Username: cfg.Username, - Password: cfg.Password, + Username: c.Username, + Password: c.Password, } - clusterConfig.ProtoVersion = cfg.ProtocolVersion + clusterConfig.ProtoVersion = c.ProtocolVersion if clusterConfig.ProtoVersion == 0 { clusterConfig.ProtoVersion = 2 } - clusterConfig.Timeout = time.Duration(cfg.ConnectTimeout) * time.Second + clusterConfig.Timeout = time.Duration(c.ConnectTimeout) * time.Second - if cfg.TLS { + if c.TLS { var tlsConfig *tls.Config - if len(cfg.Certificate) > 0 || len(cfg.IssuingCA) > 0 { - if len(cfg.Certificate) > 0 && len(cfg.PrivateKey) == 0 { + if len(c.Certificate) > 0 || len(c.IssuingCA) > 0 { + if len(c.Certificate) > 0 && len(c.PrivateKey) == 0 { return nil, fmt.Errorf("Found certificate for TLS authentication but no private key") } certBundle := &certutil.CertBundle{} - if len(cfg.Certificate) > 0 { - certBundle.Certificate = cfg.Certificate - certBundle.PrivateKey = cfg.PrivateKey + if len(c.Certificate) > 0 { + certBundle.Certificate = c.Certificate + certBundle.PrivateKey = c.PrivateKey } - if len(cfg.IssuingCA) > 0 { - certBundle.IssuingCA = cfg.IssuingCA + if len(c.IssuingCA) > 0 { + certBundle.IssuingCA = c.IssuingCA } parsedCertBundle, err := certBundle.ToParsedCertBundle() @@ -222,11 +178,11 @@ func (cp *cassandraConnectionProducer) createSession(cfg *cassandraConnectionDet if err != nil || tlsConfig == nil { return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%v", tlsConfig, err) } - tlsConfig.InsecureSkipVerify = cfg.InsecureTLS + tlsConfig.InsecureSkipVerify = c.InsecureTLS - if cfg.TLSMinVersion != "" { + if c.TLSMinVersion != "" { var ok bool - tlsConfig.MinVersion, ok = tlsutil.TLSLookup[cfg.TLSMinVersion] + tlsConfig.MinVersion, ok = tlsutil.TLSLookup[c.TLSMinVersion] if !ok { return nil, fmt.Errorf("invalid 'tls_min_version' in config") } @@ -248,8 +204,8 @@ func (cp *cassandraConnectionProducer) createSession(cfg *cassandraConnectionDet } // Set consistency - if cfg.Consistency != "" { - consistencyValue, err := gocql.ParseConsistencyWrapper(cfg.Consistency) + if c.Consistency != "" { + consistencyValue, err := gocql.ParseConsistencyWrapper(c.Consistency) if err != nil { return nil, err } diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go index d648b776fa3f..4c04c0fd4f9f 100644 --- a/builtin/logical/database/dbs/db.go +++ b/builtin/logical/database/dbs/db.go @@ -22,16 +22,12 @@ var ( func Factory(conf *DatabaseConfig) (DatabaseType, error) { switch conf.DatabaseType { case postgreSQLTypeName: - var details *sqlConnectionDetails - err := mapstructure.Decode(conf.ConnectionDetails, &details) + var connProducer *sqlConnectionProducer + err := mapstructure.Decode(conf.ConnectionDetails, &connProducer) if err != nil { return nil, err } - - connProducer := &sqlConnectionProducer{ - config: conf, - connDetails: details, - } + connProducer.config = conf credsProducer := &sqlCredentialsProducer{ displayNameLen: 23, @@ -44,16 +40,12 @@ func Factory(conf *DatabaseConfig) (DatabaseType, error) { }, nil case mySQLTypeName: - var details *sqlConnectionDetails - err := mapstructure.Decode(conf.ConnectionDetails, &details) + var connProducer *sqlConnectionProducer + err := mapstructure.Decode(conf.ConnectionDetails, &connProducer) if err != nil { return nil, err } - - connProducer := &sqlConnectionProducer{ - config: conf, - connDetails: details, - } + connProducer.config = conf credsProducer := &sqlCredentialsProducer{ displayNameLen: 4, @@ -66,16 +58,12 @@ func Factory(conf *DatabaseConfig) (DatabaseType, error) { }, nil case cassandraTypeName: - var details *cassandraConnectionDetails - err := mapstructure.Decode(conf.ConnectionDetails, &details) + var connProducer *cassandraConnectionProducer + err := mapstructure.Decode(conf.ConnectionDetails, &connProducer) if err != nil { return nil, err } - - connProducer := &cassandraConnectionProducer{ - config: conf, - connDetails: details, - } + connProducer.config = conf credsProducer := &cassandraCredentialsProducer{} diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 9fe9260508d0..085113fe98e7 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -1,6 +1,7 @@ package database import ( + "errors" "fmt" "time" @@ -10,6 +11,64 @@ import ( "github.com/hashicorp/vault/logical/framework" ) +func pathResetConnection(b *databaseBackend) *framework.Path { + return &framework.Path{ + Pattern: fmt.Sprintf("reset/%s", framework.GenericNameRegex("name")), + Fields: map[string]*framework.FieldSchema{ + "name": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "Name of this DB type", + }, + }, + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.UpdateOperation: b.pathConnectionReset, + }, + + HelpSynopsis: pathConfigConnectionHelpSyn, + HelpDescription: pathConfigConnectionHelpDesc, + } +} + +func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + name := data.Get("name").(string) + if name == "" { + return nil, errors.New("No database name set") + } + + // Grab the mutex lock + b.Lock() + defer b.Unlock() + + entry, err := req.Storage.Get(fmt.Sprintf("dbs/%s", name)) + if err != nil { + return nil, fmt.Errorf("failed to read connection configuration") + } + if entry == nil { + return nil, nil + } + + var config dbs.DatabaseConfig + if err := entry.DecodeJSON(&config); err != nil { + return nil, err + } + + db, ok := b.connections[name] + if !ok { + return logical.ErrorResponse("Can not change type of existing connection."), nil + } + + db.Close() + db, err = dbs.Factory(&config) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil + } + + b.connections[name] = db + + return nil, nil +} + func pathConfigConnection(b *databaseBackend) *framework.Path { return &framework.Path{ Pattern: fmt.Sprintf("dbs/%s", framework.GenericNameRegex("name")), @@ -129,13 +188,13 @@ func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framew if b.connections[name].Type() != connType { return logical.ErrorResponse("Can not change type of existing connection."), nil } - - db = b.connections[name] } else { db, err = dbs.Factory(config) if err != nil { return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil } + + b.connections[name] = db } /* @@ -166,9 +225,6 @@ func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framew } // Reset the DB connection - db.Reset(config) - b.connections[name] = db - resp := &logical.Response{} resp.AddWarning("Read access to this endpoint should be controlled via ACLs as it will return the connection string or URL as it is, including passwords, if any.") From 354233f91d400d784cdb73170e9cbec8a77bd349 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Fri, 3 Mar 2017 15:07:41 -0800 Subject: [PATCH 009/152] rename mysql variable --- builtin/logical/database/dbs/mysql.go | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/builtin/logical/database/dbs/mysql.go b/builtin/logical/database/dbs/mysql.go index ce6cdac92b21..30452ca543c6 100644 --- a/builtin/logical/database/dbs/mysql.go +++ b/builtin/logical/database/dbs/mysql.go @@ -21,12 +21,12 @@ type MySQL struct { sync.RWMutex } -func (p *MySQL) Type() string { +func (m *MySQL) Type() string { return mySQLTypeName } -func (p *MySQL) getConnection() (*sql.DB, error) { - db, err := p.Connection() +func (m *MySQL) getConnection() (*sql.DB, error) { + db, err := m.Connection() if err != nil { return nil, err } @@ -34,17 +34,17 @@ func (p *MySQL) getConnection() (*sql.DB, error) { return db.(*sql.DB), nil } -func (p *MySQL) CreateUser(createStmts, rollbackStmts, username, password, expiration string) error { +func (m *MySQL) CreateUser(createStmts, rollbackStmts, username, password, expiration string) error { // Get the connection - db, err := p.getConnection() + db, err := m.getConnection() if err != nil { return err } // TODO: This is racey // Grab a read lock - p.RLock() - defer p.RUnlock() + m.RLock() + defer m.RUnlock() // Start a transaction tx, err := db.Begin() @@ -82,20 +82,20 @@ func (p *MySQL) CreateUser(createStmts, rollbackStmts, username, password, expir } // NOOP -func (p *MySQL) RenewUser(username, expiration string) error { +func (m *MySQL) RenewUser(username, expiration string) error { return nil } -func (p *MySQL) RevokeUser(username, revocationStmts string) error { +func (m *MySQL) RevokeUser(username, revocationStmts string) error { // Get the connection - db, err := p.getConnection() + db, err := m.getConnection() if err != nil { return err } // Grab the read lock - p.RLock() - defer p.RUnlock() + m.RLock() + defer m.RUnlock() // Use a default SQL statement for revocation if one cannot be fetched from the role From c823ad059744ee2ad0471c4a43121a1a2e4dbc0e Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 7 Mar 2017 13:48:29 -0800 Subject: [PATCH 010/152] Update locking functionaility --- builtin/logical/database/dbs/cassandra.go | 30 ++++++------------- .../database/dbs/connectionproducer.go | 14 ++++----- builtin/logical/database/dbs/mysql.go | 22 ++++++-------- builtin/logical/database/dbs/postgresql.go | 27 ++++++++--------- 4 files changed, 36 insertions(+), 57 deletions(-) diff --git a/builtin/logical/database/dbs/cassandra.go b/builtin/logical/database/dbs/cassandra.go index 7a06e131483d..9c5607e0d1b9 100644 --- a/builtin/logical/database/dbs/cassandra.go +++ b/builtin/logical/database/dbs/cassandra.go @@ -3,7 +3,6 @@ package dbs import ( "fmt" "strings" - "sync" "github.com/gocql/gocql" "github.com/hashicorp/vault/helper/strutil" @@ -20,7 +19,6 @@ type Cassandra struct { // can close it and use a new connection; hence the lock ConnectionProducer CredentialsProducer - sync.RWMutex } func (c *Cassandra) Type() string { @@ -28,7 +26,7 @@ func (c *Cassandra) Type() string { } func (c *Cassandra) getConnection() (*gocql.Session, error) { - session, err := c.Connection() + session, err := c.connection() if err != nil { return nil, err } @@ -37,27 +35,16 @@ func (c *Cassandra) getConnection() (*gocql.Session, error) { } func (c *Cassandra) CreateUser(createStmts, rollbackStmts, username, password, expiration string) error { + // Grab the lock + c.Lock() + defer c.Unlock() + // Get the connection session, err := c.getConnection() if err != nil { return err } - // TODO: This is racey - // Grab a read lock - c.RLock() - defer c.RUnlock() - - // Set consistency - /* if .Consistency != "" { - consistencyValue, err := gocql.ParseConsistencyWrapper(role.Consistency) - if err != nil { - return err - } - - session.SetConsistency(consistencyValue) - }*/ - // Execute each query for _, query := range strutil.ParseArbitraryStringSlice(createStmts, ";") { query = strings.TrimSpace(query) @@ -94,13 +81,14 @@ func (c *Cassandra) RenewUser(username, expiration string) error { } func (c *Cassandra) RevokeUser(username, revocationStmts string) error { + // Grab the lock + c.Lock() + defer c.Unlock() + session, err := c.getConnection() if err != nil { return err } - // TODO: this is Racey - c.RLock() - defer c.RUnlock() err = session.Query(fmt.Sprintf("DROP USER '%s'", username)).Exec() if err != nil { diff --git a/builtin/logical/database/dbs/connectionproducer.go b/builtin/logical/database/dbs/connectionproducer.go index 1e66d27f6edc..268ab615c47a 100644 --- a/builtin/logical/database/dbs/connectionproducer.go +++ b/builtin/logical/database/dbs/connectionproducer.go @@ -18,8 +18,10 @@ import ( ) type ConnectionProducer interface { - Connection() (interface{}, error) Close() + + sync.Locker + connection() (interface{}, error) } // sqlConnectionProducer impliments ConnectionProducer and provides a generic producer for most sql databases @@ -32,12 +34,8 @@ type sqlConnectionProducer struct { sync.Mutex } -func (c *sqlConnectionProducer) Connection() (interface{}, error) { - // Grab the write lock - c.Lock() - defer c.Unlock() - - // If we already have a DB, we got it! +func (c *sqlConnectionProducer) connection() (interface{}, error) { + // If we already have a DB, test it and return if c.db != nil { if err := c.db.Ping(); err == nil { return c.db, nil @@ -106,7 +104,7 @@ type cassandraConnectionProducer struct { sync.Mutex } -func (c *cassandraConnectionProducer) Connection() (interface{}, error) { +func (c *cassandraConnectionProducer) connection() (interface{}, error) { // Grab the write lock c.Lock() defer c.Unlock() diff --git a/builtin/logical/database/dbs/mysql.go b/builtin/logical/database/dbs/mysql.go index 30452ca543c6..b5574d1a5ae6 100644 --- a/builtin/logical/database/dbs/mysql.go +++ b/builtin/logical/database/dbs/mysql.go @@ -3,7 +3,6 @@ package dbs import ( "database/sql" "strings" - "sync" "github.com/hashicorp/vault/helper/strutil" ) @@ -18,7 +17,6 @@ type MySQL struct { ConnectionProducer CredentialsProducer - sync.RWMutex } func (m *MySQL) Type() string { @@ -26,7 +24,7 @@ func (m *MySQL) Type() string { } func (m *MySQL) getConnection() (*sql.DB, error) { - db, err := m.Connection() + db, err := m.connection() if err != nil { return nil, err } @@ -35,17 +33,16 @@ func (m *MySQL) getConnection() (*sql.DB, error) { } func (m *MySQL) CreateUser(createStmts, rollbackStmts, username, password, expiration string) error { + // Grab the lock + m.Lock() + defer m.Unlock() + // Get the connection db, err := m.getConnection() if err != nil { return err } - // TODO: This is racey - // Grab a read lock - m.RLock() - defer m.RUnlock() - // Start a transaction tx, err := db.Begin() if err != nil { @@ -87,18 +84,17 @@ func (m *MySQL) RenewUser(username, expiration string) error { } func (m *MySQL) RevokeUser(username, revocationStmts string) error { + // Grab the read lock + m.Lock() + defer m.Unlock() + // Get the connection db, err := m.getConnection() if err != nil { return err } - // Grab the read lock - m.RLock() - defer m.RUnlock() - // Use a default SQL statement for revocation if one cannot be fetched from the role - if revocationStmts == "" { revocationStmts = defaultRevocationStmts } diff --git a/builtin/logical/database/dbs/postgresql.go b/builtin/logical/database/dbs/postgresql.go index 01fb3cd708e7..32c049721131 100644 --- a/builtin/logical/database/dbs/postgresql.go +++ b/builtin/logical/database/dbs/postgresql.go @@ -4,7 +4,6 @@ import ( "database/sql" "fmt" "strings" - "sync" "github.com/hashicorp/vault/helper/strutil" "github.com/lib/pq" @@ -15,7 +14,6 @@ type PostgreSQL struct { ConnectionProducer CredentialsProducer - sync.RWMutex } func (p *PostgreSQL) Type() string { @@ -23,7 +21,7 @@ func (p *PostgreSQL) Type() string { } func (p *PostgreSQL) getConnection() (*sql.DB, error) { - db, err := p.Connection() + db, err := p.connection() if err != nil { return nil, err } @@ -32,17 +30,16 @@ func (p *PostgreSQL) getConnection() (*sql.DB, error) { } func (p *PostgreSQL) CreateUser(createStmts, rollbackStmts, username, password, expiration string) error { + // Grab the lock + p.Lock() + defer p.Unlock() + // Get the connection db, err := p.getConnection() if err != nil { return err } - // TODO: This is racey - // Grab a read lock - p.RLock() - defer p.RUnlock() - // Start a transaction // b.logger.Trace("postgres/pathRoleCreateRead: starting transaction") tx, err := db.Begin() @@ -89,14 +86,14 @@ func (p *PostgreSQL) CreateUser(createStmts, rollbackStmts, username, password, } func (p *PostgreSQL) RenewUser(username, expiration string) error { + // Grab the lock + p.Lock() + defer p.Unlock() + db, err := p.getConnection() if err != nil { return err } - // TODO: This is Racey - // Grab the read lock - p.RLock() - defer p.RUnlock() query := fmt.Sprintf( "ALTER ROLE %s VALID UNTIL '%s';", @@ -116,9 +113,9 @@ func (p *PostgreSQL) RenewUser(username, expiration string) error { } func (p *PostgreSQL) RevokeUser(username, revocationStmts string) error { - // Grab the read lock - p.RLock() - defer p.RUnlock() + // Grab the lock + p.Lock() + defer p.Unlock() if revocationStmts == "" { return p.defaultRevokeUser(username) From 1d23bbbe2853a33ac1f1aa99b860d7b23bc4c167 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 7 Mar 2017 15:33:05 -0800 Subject: [PATCH 011/152] Remove double lock --- builtin/logical/database/dbs/connectionproducer.go | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/builtin/logical/database/dbs/connectionproducer.go b/builtin/logical/database/dbs/connectionproducer.go index 268ab615c47a..82da37cc7311 100644 --- a/builtin/logical/database/dbs/connectionproducer.go +++ b/builtin/logical/database/dbs/connectionproducer.go @@ -105,11 +105,7 @@ type cassandraConnectionProducer struct { } func (c *cassandraConnectionProducer) connection() (interface{}, error) { - // Grab the write lock - c.Lock() - defer c.Unlock() - - // If we already have a DB, we got it! + // If we already have a DB, return it if c.session != nil { return c.session, nil } From 01300e026b40fb57a7ea838b46f5257df697807f Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 7 Mar 2017 15:34:23 -0800 Subject: [PATCH 012/152] Remove unused sql object --- builtin/logical/database/dbs/mysql.go | 2 -- builtin/logical/database/dbs/postgresql.go | 2 -- 2 files changed, 4 deletions(-) diff --git a/builtin/logical/database/dbs/mysql.go b/builtin/logical/database/dbs/mysql.go index b5574d1a5ae6..0cf77062cad3 100644 --- a/builtin/logical/database/dbs/mysql.go +++ b/builtin/logical/database/dbs/mysql.go @@ -13,8 +13,6 @@ const defaultRevocationStmts = ` ` type MySQL struct { - db *sql.DB - ConnectionProducer CredentialsProducer } diff --git a/builtin/logical/database/dbs/postgresql.go b/builtin/logical/database/dbs/postgresql.go index 32c049721131..468746fc4bf7 100644 --- a/builtin/logical/database/dbs/postgresql.go +++ b/builtin/logical/database/dbs/postgresql.go @@ -10,8 +10,6 @@ import ( ) type PostgreSQL struct { - db *sql.DB - ConnectionProducer CredentialsProducer } From 78fdc2ad24751b778b5e62b1bed5ee01d27d8b74 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 7 Mar 2017 16:48:17 -0800 Subject: [PATCH 013/152] Pass statements object --- builtin/logical/database/dbs/cassandra.go | 10 ++-- builtin/logical/database/dbs/db.go | 13 +++-- builtin/logical/database/dbs/mysql.go | 13 ++--- builtin/logical/database/dbs/postgresql.go | 12 ++--- builtin/logical/database/path_role_create.go | 2 +- builtin/logical/database/path_roles.go | 50 +++++++++++++------- 6 files changed, 62 insertions(+), 38 deletions(-) diff --git a/builtin/logical/database/dbs/cassandra.go b/builtin/logical/database/dbs/cassandra.go index 9c5607e0d1b9..9956372d6fd2 100644 --- a/builtin/logical/database/dbs/cassandra.go +++ b/builtin/logical/database/dbs/cassandra.go @@ -34,7 +34,7 @@ func (c *Cassandra) getConnection() (*gocql.Session, error) { return session.(*gocql.Session), nil } -func (c *Cassandra) CreateUser(createStmts, rollbackStmts, username, password, expiration string) error { +func (c *Cassandra) CreateUser(statements Statements, username, password, expiration string) error { // Grab the lock c.Lock() defer c.Unlock() @@ -46,7 +46,7 @@ func (c *Cassandra) CreateUser(createStmts, rollbackStmts, username, password, e } // Execute each query - for _, query := range strutil.ParseArbitraryStringSlice(createStmts, ";") { + for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") { query = strings.TrimSpace(query) if len(query) == 0 { continue @@ -57,7 +57,7 @@ func (c *Cassandra) CreateUser(createStmts, rollbackStmts, username, password, e "password": password, })).Exec() if err != nil { - for _, query := range strutil.ParseArbitraryStringSlice(rollbackStmts, ";") { + for _, query := range strutil.ParseArbitraryStringSlice(statements.RollbackStatements, ";") { query = strings.TrimSpace(query) if len(query) == 0 { continue @@ -75,12 +75,12 @@ func (c *Cassandra) CreateUser(createStmts, rollbackStmts, username, password, e return nil } -func (c *Cassandra) RenewUser(username, expiration string) error { +func (c *Cassandra) RenewUser(statements Statements, username, expiration string) error { // NOOP return nil } -func (c *Cassandra) RevokeUser(username, revocationStmts string) error { +func (c *Cassandra) RevokeUser(statements Statements, username string) error { // Grab the lock c.Lock() defer c.Unlock() diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go index 4c04c0fd4f9f..e3e8cb39b71b 100644 --- a/builtin/logical/database/dbs/db.go +++ b/builtin/logical/database/dbs/db.go @@ -78,9 +78,9 @@ func Factory(conf *DatabaseConfig) (DatabaseType, error) { type DatabaseType interface { Type() string - CreateUser(createStmt, rollbackStmt, username, password, expiration string) error - RenewUser(username, expiration string) error - RevokeUser(username, revocationStmt string) error + CreateUser(statements Statements, username, password, expiration string) error + RenewUser(statements Statements, username, expiration string) error + RevokeUser(statements Statements, username string) error ConnectionProducer CredentialsProducer @@ -94,6 +94,13 @@ type DatabaseConfig struct { MaxConnectionLifetime time.Duration `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"` } +type Statements struct { + CreationStatements string `json:"creation_statments" mapstructure:"creation_statements" structs:"creation_statments"` + RevocationStatements string `json:"revocation_statements" mapstructure:"revocation_statements" structs:"revocation_statements"` + RollbackStatements string `json:"rollback_statements" mapstructure:"rollback_statements" structs:"rollback_statements"` + RenewStatements string `json:"renew_statements" mapstructure:"renew_statements" structs:"renew_statements"` +} + // Query templates a query for us. func queryHelper(tpl string, data map[string]string) string { for k, v := range data { diff --git a/builtin/logical/database/dbs/mysql.go b/builtin/logical/database/dbs/mysql.go index 0cf77062cad3..0ff0154157bb 100644 --- a/builtin/logical/database/dbs/mysql.go +++ b/builtin/logical/database/dbs/mysql.go @@ -7,7 +7,7 @@ import ( "github.com/hashicorp/vault/helper/strutil" ) -const defaultRevocationStmts = ` +const defaultMysqlRevocationStmts = ` REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'%'; DROP USER '{{name}}'@'%' ` @@ -30,7 +30,7 @@ func (m *MySQL) getConnection() (*sql.DB, error) { return db.(*sql.DB), nil } -func (m *MySQL) CreateUser(createStmts, rollbackStmts, username, password, expiration string) error { +func (m *MySQL) CreateUser(statements Statements, username, password, expiration string) error { // Grab the lock m.Lock() defer m.Unlock() @@ -49,7 +49,7 @@ func (m *MySQL) CreateUser(createStmts, rollbackStmts, username, password, expir defer tx.Rollback() // Execute each query - for _, query := range strutil.ParseArbitraryStringSlice(createStmts, ";") { + for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") { query = strings.TrimSpace(query) if len(query) == 0 { continue @@ -77,11 +77,11 @@ func (m *MySQL) CreateUser(createStmts, rollbackStmts, username, password, expir } // NOOP -func (m *MySQL) RenewUser(username, expiration string) error { +func (m *MySQL) RenewUser(statements Statements, username, expiration string) error { return nil } -func (m *MySQL) RevokeUser(username, revocationStmts string) error { +func (m *MySQL) RevokeUser(statements Statements, username string) error { // Grab the read lock m.Lock() defer m.Unlock() @@ -92,9 +92,10 @@ func (m *MySQL) RevokeUser(username, revocationStmts string) error { return err } + revocationStmts := statements.RevocationStatements // Use a default SQL statement for revocation if one cannot be fetched from the role if revocationStmts == "" { - revocationStmts = defaultRevocationStmts + revocationStmts = defaultMysqlRevocationStmts } // Start a transaction diff --git a/builtin/logical/database/dbs/postgresql.go b/builtin/logical/database/dbs/postgresql.go index 468746fc4bf7..51b72ebc8a52 100644 --- a/builtin/logical/database/dbs/postgresql.go +++ b/builtin/logical/database/dbs/postgresql.go @@ -27,7 +27,7 @@ func (p *PostgreSQL) getConnection() (*sql.DB, error) { return db.(*sql.DB), nil } -func (p *PostgreSQL) CreateUser(createStmts, rollbackStmts, username, password, expiration string) error { +func (p *PostgreSQL) CreateUser(statements Statements, username, password, expiration string) error { // Grab the lock p.Lock() defer p.Unlock() @@ -51,7 +51,7 @@ func (p *PostgreSQL) CreateUser(createStmts, rollbackStmts, username, password, // Return the secret // Execute each query - for _, query := range strutil.ParseArbitraryStringSlice(createStmts, ";") { + for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") { query = strings.TrimSpace(query) if len(query) == 0 { continue @@ -83,7 +83,7 @@ func (p *PostgreSQL) CreateUser(createStmts, rollbackStmts, username, password, return nil } -func (p *PostgreSQL) RenewUser(username, expiration string) error { +func (p *PostgreSQL) RenewUser(statements Statements, username, expiration string) error { // Grab the lock p.Lock() defer p.Unlock() @@ -110,16 +110,16 @@ func (p *PostgreSQL) RenewUser(username, expiration string) error { return nil } -func (p *PostgreSQL) RevokeUser(username, revocationStmts string) error { +func (p *PostgreSQL) RevokeUser(statements Statements, username string) error { // Grab the lock p.Lock() defer p.Unlock() - if revocationStmts == "" { + if statements.RevocationStatements == "" { return p.defaultRevokeUser(username) } - return p.customRevokeUser(username, revocationStmts) + return p.customRevokeUser(username, statements.RevocationStatements) } func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error { diff --git a/builtin/logical/database/path_role_create.go b/builtin/logical/database/path_role_create.go index b1cce97f30e9..3f7a513c89bb 100644 --- a/builtin/logical/database/path_role_create.go +++ b/builtin/logical/database/path_role_create.go @@ -67,7 +67,7 @@ func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framewo expiration := db.GenerateExpiration(role.DefaultTTL) - err = db.CreateUser(role.CreationStatements, role.RollbackStatements, username, password, expiration) + err = db.CreateUser(role.Statements, username, password, expiration) if err != nil { return nil, err } diff --git a/builtin/logical/database/path_roles.go b/builtin/logical/database/path_roles.go index 994d084f0901..1268b05a0145 100644 --- a/builtin/logical/database/path_roles.go +++ b/builtin/logical/database/path_roles.go @@ -4,6 +4,7 @@ import ( "fmt" "time" + "github.com/hashicorp/vault/builtin/logical/database/dbs" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -42,12 +43,18 @@ func pathRoles(b *databaseBackend) *framework.Path { "revocation_statements": { Type: framework.TypeString, - Description: `SQL statements to be executed to revoke a user. Must be a semicolon-separated + Description: `Statements to be executed to revoke a user. Must be a semicolon-separated + string, a base64-encoded semicolon-separated string, a serialized JSON string + array, or a base64-encoded serialized JSON string array. The '{{name}}' value + will be substituted.`, + }, + "renew_statements": { + Type: framework.TypeString, + Description: `Statements to be executed to renew a user. Must be a semicolon-separated string, a base64-encoded semicolon-separated string, a serialized JSON string array, or a base64-encoded serialized JSON string array. The '{{name}}' value will be substituted.`, }, - "rollback_statements": { Type: framework.TypeString, Description: `SQL statements to be executed to revoke a user. Must be a semicolon-separated @@ -98,9 +105,10 @@ func (b *databaseBackend) pathRoleRead(req *logical.Request, data *framework.Fie return &logical.Response{ Data: map[string]interface{}{ - "creation_statments": role.CreationStatements, - "revocation_statements": role.RevocationStatements, - "rollback_statements": role.RollbackStatements, + "creation_statments": role.Statements.CreationStatements, + "revocation_statements": role.Statements.RevocationStatements, + "rollback_statements": role.Statements.RollbackStatements, + "renew_statements": role.Statements.RenewStatements, "default_ttl": role.DefaultTTL.String(), "max_ttl": role.MaxTTL.String(), }, @@ -119,9 +127,14 @@ func (b *databaseBackend) pathRoleList(req *logical.Request, d *framework.FieldD func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { name := data.Get("name").(string) dbName := data.Get("db_name").(string) + + // Get statements creationStmts := data.Get("creation_statements").(string) revocationStmts := data.Get("revocation_statements").(string) rollbackStmts := data.Get("rollback_statements").(string) + renewStmts := data.Get("renew_statements").(string) + + // Get TTLs defaultTTLRaw := data.Get("default_ttl").(string) maxTTLRaw := data.Get("max_ttl").(string) @@ -136,16 +149,21 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F "Invalid max_ttl: %s", err)), nil } + statements := dbs.Statements{ + CreationStatements: creationStmts, + RevocationStatements: revocationStmts, + RollbackStatements: rollbackStmts, + RenewStatements: rollbackStmts, + } + // TODO: Think about preparing the statments to test. // Store it entry, err := logical.StorageEntryJSON("role/"+name, &roleEntry{ - DBName: dbName, - CreationStatements: creationStmts, - RevocationStatements: revocationStmts, - RollbackStatements: rollbackStmts, - DefaultTTL: defaultTTL, - MaxTTL: maxTTL, + DBName: dbName, + Statements: statements, + DefaultTTL: defaultTTL, + MaxTTL: maxTTL, }) if err != nil { return nil, err @@ -158,12 +176,10 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F } type roleEntry struct { - DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"` - CreationStatements string `json:"creation_statment" mapstructure:"creation_statement" structs:"creation_statment"` - RevocationStatements string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"` - RollbackStatements string `json:"revocation_statement" mapstructure:"revocation_statement" structs:"revocation_statement"` - DefaultTTL time.Duration `json:"default_ttl" mapstructure:"default_ttl" structs:"default_ttl"` - MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl" structs:"max_ttl"` + DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"` + Statements dbs.Statements `json:"statments" mapstructure:"statements" structs:"statments"` + DefaultTTL time.Duration `json:"default_ttl" mapstructure:"default_ttl" structs:"default_ttl"` + MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl" structs:"max_ttl"` } const pathRoleHelpSyn = ` From 73200db1d90d1950cbff783b5ff48489406e2bc2 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 7 Mar 2017 17:00:52 -0800 Subject: [PATCH 014/152] Add defaults to the cassandra databse type --- builtin/logical/database/dbs/cassandra.go | 13 +++++++++++-- builtin/logical/database/dbs/db.go | 2 ++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/builtin/logical/database/dbs/cassandra.go b/builtin/logical/database/dbs/cassandra.go index 9956372d6fd2..1be26766bdcf 100644 --- a/builtin/logical/database/dbs/cassandra.go +++ b/builtin/logical/database/dbs/cassandra.go @@ -45,8 +45,17 @@ func (c *Cassandra) CreateUser(statements Statements, username, password, expira return err } + creationCQL := statements.CreationStatements + if creationCQL == "" { + creationCQL = defaultCreationCQL + } + rollbackCQL := statements.RollbackStatements + if rollbackCQL == "" { + rollbackCQL = defaultRollbackCQL + } + // Execute each query - for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") { + for _, query := range strutil.ParseArbitraryStringSlice(creationCQL, ";") { query = strings.TrimSpace(query) if len(query) == 0 { continue @@ -57,7 +66,7 @@ func (c *Cassandra) CreateUser(statements Statements, username, password, expira "password": password, })).Exec() if err != nil { - for _, query := range strutil.ParseArbitraryStringSlice(statements.RollbackStatements, ";") { + for _, query := range strutil.ParseArbitraryStringSlice(rollbackCQL, ";") { query = strings.TrimSpace(query) if len(query) == 0 { continue diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go index e3e8cb39b71b..e173e2dd8ef0 100644 --- a/builtin/logical/database/dbs/db.go +++ b/builtin/logical/database/dbs/db.go @@ -94,6 +94,8 @@ type DatabaseConfig struct { MaxConnectionLifetime time.Duration `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"` } +// Statments set in role creation and passed into the database type's functions. +// TODO: Add a way of setting defaults here. type Statements struct { CreationStatements string `json:"creation_statments" mapstructure:"creation_statements" structs:"creation_statments"` RevocationStatements string `json:"revocation_statements" mapstructure:"revocation_statements" structs:"revocation_statements"` From cd68899a4ad6c7a6d69b2aed27f071d39da929c8 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 7 Mar 2017 17:21:44 -0800 Subject: [PATCH 015/152] Fix renew and revoke calls --- builtin/logical/database/path_roles.go | 2 +- builtin/logical/database/secret_creds.go | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/builtin/logical/database/path_roles.go b/builtin/logical/database/path_roles.go index 1268b05a0145..9a5bb9324dfd 100644 --- a/builtin/logical/database/path_roles.go +++ b/builtin/logical/database/path_roles.go @@ -153,7 +153,7 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F CreationStatements: creationStmts, RevocationStatements: revocationStmts, RollbackStatements: rollbackStmts, - RenewStatements: rollbackStmts, + RenewStatements: renewStmts, } // TODO: Think about preparing the statments to test. diff --git a/builtin/logical/database/secret_creds.go b/builtin/logical/database/secret_creds.go index 120804e91ce3..90b88082eb44 100644 --- a/builtin/logical/database/secret_creds.go +++ b/builtin/logical/database/secret_creds.go @@ -70,7 +70,7 @@ func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.Fi if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() { expiration := expireTime.Format("2006-01-02 15:04:05-0700") - err := db.RenewUser(username, expiration) + err := db.RenewUser(role.Statements, username, expiration) if err != nil { return nil, err } @@ -87,7 +87,6 @@ func (b *databaseBackend) secretCredsRevoke(req *logical.Request, d *framework.F } username, ok := usernameRaw.(string) - var revocationSQL string var resp *logical.Response roleNameRaw, ok := req.Secret.InternalData["role"] @@ -129,7 +128,7 @@ func (b *databaseBackend) secretCredsRevoke(req *logical.Request, d *framework.F return nil, fmt.Errorf("Could not find database with name: %s", role.DBName) } - err = db.RevokeUser(username, revocationSQL) + err = db.RevokeUser(role.Statements, username) if err != nil { return nil, err } From 00359cdea4ccac728fb8b987a28c6c42275d13ce Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 8 Mar 2017 14:46:53 -0800 Subject: [PATCH 016/152] Update secrets fields --- builtin/logical/database/dbs/mysql.go | 5 ++++ .../database/path_config_connection.go | 30 +++++++++---------- builtin/logical/database/secret_creds.go | 14 ++------- 3 files changed, 22 insertions(+), 27 deletions(-) diff --git a/builtin/logical/database/dbs/mysql.go b/builtin/logical/database/dbs/mysql.go index 0ff0154157bb..0d8be2a470f3 100644 --- a/builtin/logical/database/dbs/mysql.go +++ b/builtin/logical/database/dbs/mysql.go @@ -2,6 +2,7 @@ package dbs import ( "database/sql" + "fmt" "strings" "github.com/hashicorp/vault/helper/strutil" @@ -41,6 +42,10 @@ func (m *MySQL) CreateUser(statements Statements, username, password, expiration return err } + if statements.CreationStatements == "" { + return fmt.Errorf("Empty creation statements") + } + // Start a transaction tx, err := db.Begin() if err != nil { diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 085113fe98e7..c2fc085aec39 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -197,22 +197,22 @@ func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framew b.connections[name] = db } - /* - // Don't check the connection_url if verification is disabled - verifyConnection := data.Get("verify_connection").(bool) - if verifyConnection { - // Verify the string - db, err := sql.Open("postgres", connURL) - if err != nil { - return logical.ErrorResponse(fmt.Sprintf( - "Error validating connection info: %s", err)), nil - } - defer db.Close() - if err := db.Ping(); err != nil { - return logical.ErrorResponse(fmt.Sprintf( - "Error validating connection info: %s", err)), nil - } + /* TODO: + // Don't check the connection_url if verification is disabled + verifyConnection := data.Get("verify_connection").(bool) + if verifyConnection { + // Verify the string + db, err := sql.Open("postgres", connURL) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf( + "Error validating connection info: %s", err)), nil + } + defer db.Close() + if err := db.Ping(); err != nil { + return logical.ErrorResponse(fmt.Sprintf( + "Error validating connection info: %s", err)), nil } + } */ // Store it diff --git a/builtin/logical/database/secret_creds.go b/builtin/logical/database/secret_creds.go index 90b88082eb44..e39525a18c42 100644 --- a/builtin/logical/database/secret_creds.go +++ b/builtin/logical/database/secret_creds.go @@ -11,18 +11,8 @@ const SecretCredsType = "creds" func secretCreds(b *databaseBackend) *framework.Secret { return &framework.Secret{ - Type: SecretCredsType, - Fields: map[string]*framework.FieldSchema{ - "username": &framework.FieldSchema{ - Type: framework.TypeString, - Description: "Username", - }, - - "password": &framework.FieldSchema{ - Type: framework.TypeString, - Description: "Password", - }, - }, + Type: SecretCredsType, + Fields: map[string]*framework.FieldSchema{}, Renew: b.secretCredsRenew, Revoke: b.secretCredsRevoke, From d4ea6c17689ae2a6b61ff2f56b03d6fe21512f01 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 9 Mar 2017 17:43:37 -0800 Subject: [PATCH 017/152] Add plugin features --- .../logical/database/dbs/credentialsproducer.go | 10 +++++----- builtin/logical/database/dbs/db.go | 16 +++++++++++++++- .../logical/database/path_config_connection.go | 10 ++++++++++ builtin/logical/database/path_role_create.go | 5 ++++- 4 files changed, 34 insertions(+), 7 deletions(-) diff --git a/builtin/logical/database/dbs/credentialsproducer.go b/builtin/logical/database/dbs/credentialsproducer.go index 94fce6275a2e..5ae3b128e6c2 100644 --- a/builtin/logical/database/dbs/credentialsproducer.go +++ b/builtin/logical/database/dbs/credentialsproducer.go @@ -11,7 +11,7 @@ import ( type CredentialsProducer interface { GenerateUsername(displayName string) (string, error) GeneratePassword() (string, error) - GenerateExpiration(ttl time.Duration) string + GenerateExpiration(ttl time.Duration) (string, error) } // sqlCredentialsProducer impliments CredentialsProducer and provides a generic credentials producer for most sql database types. @@ -46,10 +46,10 @@ func (scp *sqlCredentialsProducer) GeneratePassword() (string, error) { return password, nil } -func (scp *sqlCredentialsProducer) GenerateExpiration(ttl time.Duration) string { +func (scp *sqlCredentialsProducer) GenerateExpiration(ttl time.Duration) (string, error) { return time.Now(). Add(ttl). - Format("2006-01-02 15:04:05-0700") + Format("2006-01-02 15:04:05-0700"), nil } type cassandraCredentialsProducer struct{} @@ -74,6 +74,6 @@ func (ccp *cassandraCredentialsProducer) GeneratePassword() (string, error) { return password, nil } -func (ccp *cassandraCredentialsProducer) GenerateExpiration(ttl time.Duration) string { - return "" +func (ccp *cassandraCredentialsProducer) GenerateExpiration(ttl time.Duration) (string, error) { + return "", nil } diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go index e173e2dd8ef0..063cc89cf1f8 100644 --- a/builtin/logical/database/dbs/db.go +++ b/builtin/logical/database/dbs/db.go @@ -13,6 +13,7 @@ const ( postgreSQLTypeName = "postgres" mySQLTypeName = "mysql" cassandraTypeName = "cassandra" + pluginTypeName = "plugin" ) var ( @@ -71,6 +72,18 @@ func Factory(conf *DatabaseConfig) (DatabaseType, error) { ConnectionProducer: connProducer, CredentialsProducer: credsProducer, }, nil + + case pluginTypeName: + if conf.PluginCommand == "" { + return nil, errors.New("ERROR") + } + + db, err := newPluginClient(conf.PluginCommand) + if err != nil { + return nil, err + } + + return db, nil } return nil, ErrUnsupportedDatabaseType @@ -82,7 +95,7 @@ type DatabaseType interface { RenewUser(statements Statements, username, expiration string) error RevokeUser(statements Statements, username string) error - ConnectionProducer + Close() CredentialsProducer } @@ -92,6 +105,7 @@ type DatabaseConfig struct { MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"` MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"` MaxConnectionLifetime time.Duration `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"` + PluginCommand string `json:"plugin_command" structs:"plugin_command" mapstructure:"plugin_command"` } // Statments set in role creation and passed into the database type's functions. diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index c2fc085aec39..4e1da240c8ec 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -111,6 +111,12 @@ reduced to the same size.`, Description: `Maximum amount of time a connection may be reused; a zero or negative value reuses connections forever.`, }, + + "plugin_command": &framework.FieldSchema{ + Type: framework.TypeString, + Description: `Maximum amount of time a connection may be reused; + a zero or negative value reuses connections forever.`, + }, }, Callbacks: map[logical.Operation]framework.OperationFunc{ @@ -146,6 +152,9 @@ func (b *databaseBackend) pathConnectionRead(req *logical.Request, data *framewo func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { connType := data.Get("connection_type").(string) + if connType == "" { + return logical.ErrorResponse("connection_type not set"), nil + } maxOpenConns := data.Get("max_open_connections").(int) if maxOpenConns == 0 { @@ -173,6 +182,7 @@ func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framew MaxOpenConnections: maxOpenConns, MaxIdleConnections: maxIdleConns, MaxConnectionLifetime: maxConnLifetime, + PluginCommand: data.Get("plugin_command").(string), } name := data.Get("name").(string) diff --git a/builtin/logical/database/path_role_create.go b/builtin/logical/database/path_role_create.go index 3f7a513c89bb..c7989c25d870 100644 --- a/builtin/logical/database/path_role_create.go +++ b/builtin/logical/database/path_role_create.go @@ -65,7 +65,10 @@ func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framewo return nil, err } - expiration := db.GenerateExpiration(role.DefaultTTL) + expiration, err := db.GenerateExpiration(role.DefaultTTL) + if err != nil { + return nil, err + } err = db.CreateUser(role.Statements, username, password, expiration) if err != nil { From 3766ab14e5e50b5b1c37f4a5dd6c8f08192bc0f9 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 9 Mar 2017 17:43:58 -0800 Subject: [PATCH 018/152] Add plugin file --- builtin/logical/database/dbs/plugin.go | 242 +++++++++++++++++++++++++ 1 file changed, 242 insertions(+) create mode 100644 builtin/logical/database/dbs/plugin.go diff --git a/builtin/logical/database/dbs/plugin.go b/builtin/logical/database/dbs/plugin.go new file mode 100644 index 000000000000..e495dbf14780 --- /dev/null +++ b/builtin/logical/database/dbs/plugin.go @@ -0,0 +1,242 @@ +package dbs + +import ( + "net/rpc" + "os/exec" + "sync" + "time" + + "github.com/hashicorp/go-plugin" +) + +// handshakeConfigs are used to just do a basic handshake between +// a plugin and host. If the handshake fails, a user friendly error is shown. +// This prevents users from executing bad plugins or executing a plugin +// directory. It is a UX feature, not a security feature. +var handshakeConfig = plugin.HandshakeConfig{ + ProtocolVersion: 1, + MagicCookieKey: "BASIC_PLUGIN", + MagicCookieValue: "hello", +} + +type DatabasePlugin struct { + impl DatabaseType +} + +func (d DatabasePlugin) Server(*plugin.MuxBroker) (interface{}, error) { + return &databasePluginRPCServer{impl: d.impl}, nil +} + +func (DatabasePlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, error) { + return &databasePluginRPCClient{client: c}, nil +} + +type DatabasePluginClient struct { + client *plugin.Client + sync.Mutex + + *databasePluginRPCClient +} + +func (dc *DatabasePluginClient) Close() { + dc.databasePluginRPCClient.Close() + + dc.client.Kill() +} + +func newPluginClient(command string) (DatabaseType, error) { + // pluginMap is the map of plugins we can dispense. + var pluginMap = map[string]plugin.Plugin{ + "database": new(DatabasePlugin), + } + + client := plugin.NewClient(&plugin.ClientConfig{ + HandshakeConfig: handshakeConfig, + Plugins: pluginMap, + Cmd: exec.Command(command), + }) + + // Connect via RPC + rpcClient, err := client.Client() + if err != nil { + return nil, err + } + + // Request the plugin + raw, err := rpcClient.Dispense("database") + if err != nil { + return nil, err + } + + // We should have a Greeter now! This feels like a normal interface + // implementation but is in fact over an RPC connection. + databaseRPC := raw.(*databasePluginRPCClient) + + return &DatabasePluginClient{ + client: client, + databasePluginRPCClient: databaseRPC, + }, nil +} + +func NewPluginServer(db DatabaseType) { + dbPlugin := &DatabasePlugin{ + impl: db, + } + + // pluginMap is the map of plugins we can dispense. + var pluginMap = map[string]plugin.Plugin{ + "database": dbPlugin, + } + + plugin.Serve(&plugin.ServeConfig{ + HandshakeConfig: handshakeConfig, + Plugins: pluginMap, + }) +} + +// ---- RPC client domain ---- + +type databasePluginRPCClient struct { + client *rpc.Client +} + +func (dr *databasePluginRPCClient) Type() string { + return "plugin" +} + +func (dr *databasePluginRPCClient) CreateUser(statements Statements, username, password, expiration string) error { + req := CreateUserRequest{ + Statements: statements, + Username: username, + Password: password, + Expiration: expiration, + } + + err := dr.client.Call("Plugin.CreateUser", req, &struct{}{}) + + return err +} + +func (dr *databasePluginRPCClient) RenewUser(statements Statements, username, expiration string) error { + req := RenewUserRequest{ + Statements: statements, + Username: username, + Expiration: expiration, + } + + err := dr.client.Call("Plugin.RenewUser", req, &struct{}{}) + + return err +} + +func (dr *databasePluginRPCClient) RevokeUser(statements Statements, username string) error { + req := RevokeUserRequest{ + Statements: statements, + Username: username, + } + + err := dr.client.Call("Plugin.RevokeUser", req, &struct{}{}) + + return err +} + +func (dr *databasePluginRPCClient) Close() error { + err := dr.client.Call("Plugin.Close", struct{}{}, &struct{}{}) + + return err +} + +func (dr *databasePluginRPCClient) GenerateUsername(displayName string) (string, error) { + var username string + err := dr.client.Call("Plugin.GenerateUsername", displayName, &username) + + return username, err +} + +func (dr *databasePluginRPCClient) GeneratePassword() (string, error) { + var password string + err := dr.client.Call("Plugin.GeneratePassword", struct{}{}, &password) + + return password, err +} + +func (dr *databasePluginRPCClient) GenerateExpiration(duration time.Duration) (string, error) { + var expiration string + err := dr.client.Call("Plugin.GenerateExpiration", duration, &expiration) + + return expiration, err +} + +// ---- RPC server domain ---- +type databasePluginRPCServer struct { + impl DatabaseType +} + +func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error { + *resp = "string" + return nil +} + +func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequest, _ *struct{}) error { + err := ds.impl.CreateUser(args.Statements, args.Username, args.Password, args.Expiration) + + return err +} + +func (ds *databasePluginRPCServer) RenewUser(args *RenewUserRequest, _ *struct{}) error { + err := ds.impl.RenewUser(args.Statements, args.Username, args.Expiration) + + return err +} + +func (ds *databasePluginRPCServer) RevokeUser(args *RevokeUserRequest, _ *struct{}) error { + err := ds.impl.RevokeUser(args.Statements, args.Username) + + return err +} + +func (ds *databasePluginRPCServer) Close(_ interface{}, _ *struct{}) error { + ds.impl.Close() + return nil +} + +func (ds *databasePluginRPCServer) GenerateUsername(args string, resp *string) error { + var err error + *resp, err = ds.impl.GenerateUsername(args) + + return err +} + +func (ds *databasePluginRPCServer) GeneratePassword(_ struct{}, resp *string) error { + var err error + *resp, err = ds.impl.GeneratePassword() + + return err +} + +func (ds *databasePluginRPCServer) GenerateExpiration(args time.Duration, resp *string) error { + var err error + *resp, err = ds.impl.GenerateExpiration(args) + + return err +} + +// ---- Request Args domain ---- + +type CreateUserRequest struct { + Statements Statements + Username string + Password string + Expiration string +} + +type RenewUserRequest struct { + Statements Statements + Username string + Expiration string +} + +type RevokeUserRequest struct { + Statements Statements + Username string +} From b63147b7c2c1641cb93d3073689c46964b1f95e0 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 9 Mar 2017 21:31:29 -0800 Subject: [PATCH 019/152] Add special path to enforce root on plugin configuration --- builtin/logical/database/backend.go | 9 +- builtin/logical/database/dbs/db.go | 33 ++- .../database/path_config_connection.go | 201 ++++++++++-------- 3 files changed, 138 insertions(+), 105 deletions(-) diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index fe853d3fb4c5..e06e7b381b79 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -20,8 +20,15 @@ func Backend(conf *logical.BackendConfig) *databaseBackend { b.Backend = &framework.Backend{ Help: strings.TrimSpace(backendHelp), + PathsSpecial: &logical.Paths{ + Root: []string{ + "dbs/plugin/*", + }, + }, + Paths: []*framework.Path{ - pathConfigConnection(&b), + pathConfigureConnection(&b), + pathConfigurePluginConnection(&b), pathListRoles(&b), pathRoles(&b), pathRoleCreate(&b), diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go index 063cc89cf1f8..bf78d29e6530 100644 --- a/builtin/logical/database/dbs/db.go +++ b/builtin/logical/database/dbs/db.go @@ -20,7 +20,9 @@ var ( ErrUnsupportedDatabaseType = errors.New("Unsupported database type") ) -func Factory(conf *DatabaseConfig) (DatabaseType, error) { +type Factory func(*DatabaseConfig) (DatabaseType, error) + +func BuiltinFactory(conf *DatabaseConfig) (DatabaseType, error) { switch conf.DatabaseType { case postgreSQLTypeName: var connProducer *sqlConnectionProducer @@ -72,21 +74,22 @@ func Factory(conf *DatabaseConfig) (DatabaseType, error) { ConnectionProducer: connProducer, CredentialsProducer: credsProducer, }, nil + } - case pluginTypeName: - if conf.PluginCommand == "" { - return nil, errors.New("ERROR") - } + return nil, ErrUnsupportedDatabaseType +} - db, err := newPluginClient(conf.PluginCommand) - if err != nil { - return nil, err - } +func PluginFactory(conf *DatabaseConfig) (DatabaseType, error) { + if conf.PluginCommand == "" { + return nil, errors.New("ERROR") + } - return db, nil + db, err := newPluginClient(conf.PluginCommand) + if err != nil { + return nil, err } - return nil, ErrUnsupportedDatabaseType + return db, nil } type DatabaseType interface { @@ -108,6 +111,14 @@ type DatabaseConfig struct { PluginCommand string `json:"plugin_command" structs:"plugin_command" mapstructure:"plugin_command"` } +func (dc *DatabaseConfig) GetFactory() Factory { + if dc.DatabaseType == pluginTypeName { + return PluginFactory + } + + return BuiltinFactory +} + // Statments set in role creation and passed into the database type's functions. // TODO: Add a way of setting defaults here. type Statements struct { diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 4e1da240c8ec..4780dc492b7c 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -59,7 +59,10 @@ func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framew } db.Close() - db, err = dbs.Factory(&config) + + factory := config.GetFactory() + + db, err = factory(&config) if err != nil { return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil } @@ -69,9 +72,17 @@ func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framew return nil, nil } -func pathConfigConnection(b *databaseBackend) *framework.Path { +func pathConfigureConnection(b *databaseBackend) *framework.Path { + return buildConfigConnectionPath("dbs/%s", b.connectionWriteHandler(dbs.BuiltinFactory), b.connectionReadHandler()) +} + +func pathConfigurePluginConnection(b *databaseBackend) *framework.Path { + return buildConfigConnectionPath("dbs/plugin/%s", b.connectionWriteHandler(dbs.PluginFactory), b.connectionReadHandler()) +} + +func buildConfigConnectionPath(path string, updateOp, readOp framework.OperationFunc) *framework.Path { return &framework.Path{ - Pattern: fmt.Sprintf("dbs/%s", framework.GenericNameRegex("name")), + Pattern: fmt.Sprintf(path, framework.GenericNameRegex("name")), Fields: map[string]*framework.FieldSchema{ "name": &framework.FieldSchema{ Type: framework.TypeString, @@ -120,8 +131,8 @@ reduced to the same size.`, }, Callbacks: map[logical.Operation]framework.OperationFunc{ - logical.UpdateOperation: b.pathConnectionWrite, - logical.ReadOperation: b.pathConnectionRead, + logical.UpdateOperation: updateOp, + logical.ReadOperation: readOp, }, HelpSynopsis: pathConfigConnectionHelpSyn, @@ -130,115 +141,119 @@ reduced to the same size.`, } // pathConnectionRead reads out the connection configuration -func (b *databaseBackend) pathConnectionRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - name := data.Get("name").(string) +func (b *databaseBackend) connectionReadHandler() framework.OperationFunc { + return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + name := data.Get("name").(string) - entry, err := req.Storage.Get(fmt.Sprintf("dbs/%s", name)) - if err != nil { - return nil, fmt.Errorf("failed to read connection configuration") - } - if entry == nil { - return nil, nil - } + entry, err := req.Storage.Get(fmt.Sprintf("dbs/%s", name)) + if err != nil { + return nil, fmt.Errorf("failed to read connection configuration") + } + if entry == nil { + return nil, nil + } - var config dbs.DatabaseConfig - if err := entry.DecodeJSON(&config); err != nil { - return nil, err + var config dbs.DatabaseConfig + if err := entry.DecodeJSON(&config); err != nil { + return nil, err + } + return &logical.Response{ + Data: structs.New(config).Map(), + }, nil } - return &logical.Response{ - Data: structs.New(config).Map(), - }, nil } -func (b *databaseBackend) pathConnectionWrite(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - connType := data.Get("connection_type").(string) - if connType == "" { - return logical.ErrorResponse("connection_type not set"), nil - } +func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework.OperationFunc { + return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + connType := data.Get("connection_type").(string) + if connType == "" { + return logical.ErrorResponse("connection_type not set"), nil + } - maxOpenConns := data.Get("max_open_connections").(int) - if maxOpenConns == 0 { - maxOpenConns = 2 - } + maxOpenConns := data.Get("max_open_connections").(int) + if maxOpenConns == 0 { + maxOpenConns = 2 + } - maxIdleConns := data.Get("max_idle_connections").(int) - if maxIdleConns == 0 { - maxIdleConns = maxOpenConns - } - if maxIdleConns > maxOpenConns { - maxIdleConns = maxOpenConns - } + maxIdleConns := data.Get("max_idle_connections").(int) + if maxIdleConns == 0 { + maxIdleConns = maxOpenConns + } + if maxIdleConns > maxOpenConns { + maxIdleConns = maxOpenConns + } - maxConnLifetimeRaw := data.Get("max_connection_lifetime").(string) - maxConnLifetime, err := time.ParseDuration(maxConnLifetimeRaw) - if err != nil { - return logical.ErrorResponse(fmt.Sprintf( - "Invalid max_connection_lifetime: %s", err)), nil - } + maxConnLifetimeRaw := data.Get("max_connection_lifetime").(string) + maxConnLifetime, err := time.ParseDuration(maxConnLifetimeRaw) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf( + "Invalid max_connection_lifetime: %s", err)), nil + } - config := &dbs.DatabaseConfig{ - DatabaseType: connType, - ConnectionDetails: data.Raw, - MaxOpenConnections: maxOpenConns, - MaxIdleConnections: maxIdleConns, - MaxConnectionLifetime: maxConnLifetime, - PluginCommand: data.Get("plugin_command").(string), - } + config := &dbs.DatabaseConfig{ + DatabaseType: connType, + ConnectionDetails: data.Raw, + MaxOpenConnections: maxOpenConns, + MaxIdleConnections: maxIdleConns, + MaxConnectionLifetime: maxConnLifetime, + PluginCommand: data.Get("plugin_command").(string), + } - name := data.Get("name").(string) + name := data.Get("name").(string) - // Grab the mutex lock - b.Lock() - defer b.Unlock() + // Grab the mutex lock + b.Lock() + defer b.Unlock() - var db dbs.DatabaseType - if _, ok := b.connections[name]; ok { + var db dbs.DatabaseType + if _, ok := b.connections[name]; ok { - // Don't allow the connection type to change - if b.connections[name].Type() != connType { - return logical.ErrorResponse("Can not change type of existing connection."), nil - } - } else { - db, err = dbs.Factory(config) - if err != nil { - return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil + // Don't allow the connection type to change + if b.connections[name].Type() != connType { + return logical.ErrorResponse("Can not change type of existing connection."), nil + } + } else { + db, err = factory(config) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil + } + + b.connections[name] = db } - b.connections[name] = db - } + /* TODO: + // Don't check the connection_url if verification is disabled + verifyConnection := data.Get("verify_connection").(bool) + if verifyConnection { + // Verify the string + db, err := sql.Open("postgres", connURL) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf( + "Error validating connection info: %s", err)), nil + } + defer db.Close() + if err := db.Ping(); err != nil { + return logical.ErrorResponse(fmt.Sprintf( + "Error validating connection info: %s", err)), nil + } + } + */ - /* TODO: - // Don't check the connection_url if verification is disabled - verifyConnection := data.Get("verify_connection").(bool) - if verifyConnection { - // Verify the string - db, err := sql.Open("postgres", connURL) + // Store it + entry, err := logical.StorageEntryJSON(fmt.Sprintf("dbs/%s", name), config) if err != nil { - return logical.ErrorResponse(fmt.Sprintf( - "Error validating connection info: %s", err)), nil + return nil, err } - defer db.Close() - if err := db.Ping(); err != nil { - return logical.ErrorResponse(fmt.Sprintf( - "Error validating connection info: %s", err)), nil + if err := req.Storage.Put(entry); err != nil { + return nil, err } - } - */ - // Store it - entry, err := logical.StorageEntryJSON(fmt.Sprintf("dbs/%s", name), config) - if err != nil { - return nil, err - } - if err := req.Storage.Put(entry); err != nil { - return nil, err - } - - // Reset the DB connection - resp := &logical.Response{} - resp.AddWarning("Read access to this endpoint should be controlled via ACLs as it will return the connection string or URL as it is, including passwords, if any.") + // Reset the DB connection + resp := &logical.Response{} + resp.AddWarning("Read access to this endpoint should be controlled via ACLs as it will return the connection string or URL as it is, including passwords, if any.") - return resp, nil + return resp, nil + } } const pathConfigConnectionHelpSyn = ` From 72a878b180a16651226ee629e6eb808eb3891a1c Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 9 Mar 2017 22:35:45 -0800 Subject: [PATCH 020/152] Rename reset to close --- builtin/logical/database/backend.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index e06e7b381b79..69d91f6f278e 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -39,7 +39,7 @@ func Backend(conf *logical.BackendConfig) *databaseBackend { secretCreds(&b), }, - Clean: b.resetAllDBs, + Clean: b.closeAllDBs, } b.logger = conf.Logger @@ -56,7 +56,7 @@ type databaseBackend struct { } // resetAllDBs closes all connections from all database types -func (b *databaseBackend) resetAllDBs() { +func (b *databaseBackend) closeAllDBs() { b.logger.Trace("postgres/resetdb: enter") defer b.logger.Trace("postgres/resetdb: exit") From a0d207e254f66a8c1bae8005b79133cec9ca368b Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Fri, 10 Mar 2017 14:10:42 -0800 Subject: [PATCH 021/152] Add checksum attribute --- builtin/logical/database/dbs/db.go | 7 ++++++- builtin/logical/database/dbs/plugin.go | 2 +- builtin/logical/database/path_config_connection.go | 7 +++++++ 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go index bf78d29e6530..33cf7361a181 100644 --- a/builtin/logical/database/dbs/db.go +++ b/builtin/logical/database/dbs/db.go @@ -84,7 +84,11 @@ func PluginFactory(conf *DatabaseConfig) (DatabaseType, error) { return nil, errors.New("ERROR") } - db, err := newPluginClient(conf.PluginCommand) + if conf.PluginChecksum == "" { + return nil, errors.New("ERROR") + } + + db, err := newPluginClient(conf.PluginCommand, conf.PluginChecksum) if err != nil { return nil, err } @@ -109,6 +113,7 @@ type DatabaseConfig struct { MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"` MaxConnectionLifetime time.Duration `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"` PluginCommand string `json:"plugin_command" structs:"plugin_command" mapstructure:"plugin_command"` + PluginChecksum string `json:"plugin_checksum" structs:"plugin_checksum" mapstructure:"plugin_checksum"` } func (dc *DatabaseConfig) GetFactory() Factory { diff --git a/builtin/logical/database/dbs/plugin.go b/builtin/logical/database/dbs/plugin.go index e495dbf14780..bbd8d4ce4c5e 100644 --- a/builtin/logical/database/dbs/plugin.go +++ b/builtin/logical/database/dbs/plugin.go @@ -44,7 +44,7 @@ func (dc *DatabasePluginClient) Close() { dc.client.Kill() } -func newPluginClient(command string) (DatabaseType, error) { +func newPluginClient(command, checksum string) (DatabaseType, error) { // pluginMap is the map of plugins we can dispense. var pluginMap = map[string]plugin.Plugin{ "database": new(DatabasePlugin), diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 4780dc492b7c..31f6182817eb 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -128,6 +128,12 @@ reduced to the same size.`, Description: `Maximum amount of time a connection may be reused; a zero or negative value reuses connections forever.`, }, + + "plugin_checksum": &framework.FieldSchema{ + Type: framework.TypeString, + Description: `Maximum amount of time a connection may be reused; + a zero or negative value reuses connections forever.`, + }, }, Callbacks: map[logical.Operation]framework.OperationFunc{ @@ -197,6 +203,7 @@ func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework. MaxIdleConnections: maxIdleConns, MaxConnectionLifetime: maxConnLifetime, PluginCommand: data.Get("plugin_command").(string), + PluginChecksum: data.Get("plugin_checksum").(string), } name := data.Get("name").(string) From c111b02568f895e8712619cc995f6bc52bc8127f Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 13 Mar 2017 14:39:55 -0700 Subject: [PATCH 022/152] Add a way to initalize plugins and builtin databases the same way. --- .../database/dbs/connectionproducer.go | 54 +++++++++++++++++-- builtin/logical/database/dbs/db.go | 21 ++------ builtin/logical/database/dbs/plugin.go | 12 +++++ .../database/path_config_connection.go | 24 +++++++++ 4 files changed, 90 insertions(+), 21 deletions(-) diff --git a/builtin/logical/database/dbs/connectionproducer.go b/builtin/logical/database/dbs/connectionproducer.go index 82da37cc7311..8d05e5d9e6b4 100644 --- a/builtin/logical/database/dbs/connectionproducer.go +++ b/builtin/logical/database/dbs/connectionproducer.go @@ -3,6 +3,7 @@ package dbs import ( "crypto/tls" "database/sql" + "errors" "fmt" "strings" "sync" @@ -11,14 +12,20 @@ import ( // Import sql drivers _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" + "github.com/mitchellh/mapstructure" "github.com/gocql/gocql" "github.com/hashicorp/vault/helper/certutil" "github.com/hashicorp/vault/helper/tlsutil" ) +var ( + errNotInitalized = errors.New("Connection has not been initalized") +) + type ConnectionProducer interface { Close() + Initialize(map[string]interface{}) error sync.Locker connection() (interface{}, error) @@ -30,10 +37,28 @@ type sqlConnectionProducer struct { config *DatabaseConfig - db *sql.DB + initalized bool + db *sql.DB sync.Mutex } +func (c *sqlConnectionProducer) Initialize(conf map[string]interface{}) error { + c.Lock() + defer c.Unlock() + + err := mapstructure.Decode(conf, c) + if err != nil { + return err + } + c.initalized = true + + if _, err := c.connection(); err != nil { + return fmt.Errorf("Error Initalizing Connection: %s", err) + } + + return nil +} + func (c *sqlConnectionProducer) connection() (interface{}, error) { // If we already have a DB, test it and return if c.db != nil { @@ -98,13 +123,34 @@ type cassandraConnectionProducer struct { TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"` Consistency string `json:"consistency" structs:"consistency" mapstructure:"consistency"` - config *DatabaseConfig - - session *gocql.Session + config *DatabaseConfig + initalized bool + session *gocql.Session sync.Mutex } +func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}) error { + c.Lock() + defer c.Unlock() + + err := mapstructure.Decode(conf, c) + if err != nil { + return err + } + c.initalized = true + + if _, err := c.connection(); err != nil { + return fmt.Errorf("Error Initalizing Connection: %s", err) + } + + return nil +} + func (c *cassandraConnectionProducer) connection() (interface{}, error) { + if !c.initalized { + return nil, errNotInitalized + } + // If we already have a DB, return it if c.session != nil { return c.session, nil diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go index 33cf7361a181..98443f8f2948 100644 --- a/builtin/logical/database/dbs/db.go +++ b/builtin/logical/database/dbs/db.go @@ -5,8 +5,6 @@ import ( "fmt" "strings" "time" - - "github.com/mitchellh/mapstructure" ) const ( @@ -25,11 +23,7 @@ type Factory func(*DatabaseConfig) (DatabaseType, error) func BuiltinFactory(conf *DatabaseConfig) (DatabaseType, error) { switch conf.DatabaseType { case postgreSQLTypeName: - var connProducer *sqlConnectionProducer - err := mapstructure.Decode(conf.ConnectionDetails, &connProducer) - if err != nil { - return nil, err - } + connProducer := &sqlConnectionProducer{} connProducer.config = conf credsProducer := &sqlCredentialsProducer{ @@ -43,11 +37,7 @@ func BuiltinFactory(conf *DatabaseConfig) (DatabaseType, error) { }, nil case mySQLTypeName: - var connProducer *sqlConnectionProducer - err := mapstructure.Decode(conf.ConnectionDetails, &connProducer) - if err != nil { - return nil, err - } + connProducer := &sqlConnectionProducer{} connProducer.config = conf credsProducer := &sqlCredentialsProducer{ @@ -61,11 +51,7 @@ func BuiltinFactory(conf *DatabaseConfig) (DatabaseType, error) { }, nil case cassandraTypeName: - var connProducer *cassandraConnectionProducer - err := mapstructure.Decode(conf.ConnectionDetails, &connProducer) - if err != nil { - return nil, err - } + connProducer := &cassandraConnectionProducer{} connProducer.config = conf credsProducer := &cassandraCredentialsProducer{} @@ -102,6 +88,7 @@ type DatabaseType interface { RenewUser(statements Statements, username, expiration string) error RevokeUser(statements Statements, username string) error + Initialize(map[string]interface{}) error Close() CredentialsProducer } diff --git a/builtin/logical/database/dbs/plugin.go b/builtin/logical/database/dbs/plugin.go index bbd8d4ce4c5e..b244a33fc667 100644 --- a/builtin/logical/database/dbs/plugin.go +++ b/builtin/logical/database/dbs/plugin.go @@ -140,6 +140,12 @@ func (dr *databasePluginRPCClient) RevokeUser(statements Statements, username st return err } +func (dr *databasePluginRPCClient) Initialize(conf map[string]interface{}) error { + err := dr.client.Call("Plugin.Initialize", conf, &struct{}{}) + + return err +} + func (dr *databasePluginRPCClient) Close() error { err := dr.client.Call("Plugin.Close", struct{}{}, &struct{}{}) @@ -195,6 +201,12 @@ func (ds *databasePluginRPCServer) RevokeUser(args *RevokeUserRequest, _ *struct return err } +func (ds *databasePluginRPCServer) Initialize(args map[string]interface{}, _ *struct{}) error { + err := ds.impl.Initialize(args) + + return err +} + func (ds *databasePluginRPCServer) Close(_ interface{}, _ *struct{}) error { ds.impl.Close() return nil diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 31f6182817eb..6c0a63a11f26 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -3,6 +3,7 @@ package database import ( "errors" "fmt" + "strings" "time" "github.com/fatih/structs" @@ -67,6 +68,11 @@ func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framew return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil } + err = db.Initialize(config.ConnectionDetails) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil + } + b.connections[name] = db return nil, nil @@ -207,6 +213,11 @@ func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework. } name := data.Get("name").(string) + if name == "" { + return logical.ErrorResponse("Empty name attribute given"), nil + } + + verifyConnection := data.Get("verify_connection").(bool) // Grab the mutex lock b.Lock() @@ -225,6 +236,19 @@ func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework. return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil } + err := db.Initialize(config.ConnectionDetails) + if err != nil { + if !strings.Contains(err.Error(), "Error Initializing Connection") { + return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil + + } + + if verifyConnection { + return logical.ErrorResponse(err.Error()), nil + + } + } + b.connections[name] = db } From 143166b1baf135cc899b2b3fcf9d95e35c878d23 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 14 Mar 2017 13:11:28 -0700 Subject: [PATCH 023/152] Add a metrics middleware --- .../database/dbs/connectionproducer.go | 10 +- builtin/logical/database/dbs/db.go | 27 +++- .../logical/database/dbs/metricsmiddleware.go | 145 ++++++++++++++++++ builtin/logical/database/dbs/plugin.go | 9 +- 4 files changed, 176 insertions(+), 15 deletions(-) create mode 100644 builtin/logical/database/dbs/metricsmiddleware.go diff --git a/builtin/logical/database/dbs/connectionproducer.go b/builtin/logical/database/dbs/connectionproducer.go index 8d05e5d9e6b4..1e944c7b964b 100644 --- a/builtin/logical/database/dbs/connectionproducer.go +++ b/builtin/logical/database/dbs/connectionproducer.go @@ -24,7 +24,7 @@ var ( ) type ConnectionProducer interface { - Close() + Close() error Initialize(map[string]interface{}) error sync.Locker @@ -97,7 +97,7 @@ func (c *sqlConnectionProducer) connection() (interface{}, error) { return c.db, nil } -func (c *sqlConnectionProducer) Close() { +func (c *sqlConnectionProducer) Close() error { // Grab the write lock c.Lock() defer c.Unlock() @@ -107,6 +107,8 @@ func (c *sqlConnectionProducer) Close() { } c.db = nil + + return nil } type cassandraConnectionProducer struct { @@ -167,7 +169,7 @@ func (c *cassandraConnectionProducer) connection() (interface{}, error) { return session, nil } -func (c *cassandraConnectionProducer) Close() { +func (c *cassandraConnectionProducer) Close() error { // Grab the write lock c.Lock() defer c.Unlock() @@ -177,6 +179,8 @@ func (c *cassandraConnectionProducer) Close() { } c.session = nil + + return nil } func (c *cassandraConnectionProducer) createSession() (*gocql.Session, error) { diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go index 98443f8f2948..2cc42a731f84 100644 --- a/builtin/logical/database/dbs/db.go +++ b/builtin/logical/database/dbs/db.go @@ -21,6 +21,8 @@ var ( type Factory func(*DatabaseConfig) (DatabaseType, error) func BuiltinFactory(conf *DatabaseConfig) (DatabaseType, error) { + var dbType DatabaseType + switch conf.DatabaseType { case postgreSQLTypeName: connProducer := &sqlConnectionProducer{} @@ -31,10 +33,10 @@ func BuiltinFactory(conf *DatabaseConfig) (DatabaseType, error) { usernameLen: 63, } - return &PostgreSQL{ + dbType = &PostgreSQL{ ConnectionProducer: connProducer, CredentialsProducer: credsProducer, - }, nil + } case mySQLTypeName: connProducer := &sqlConnectionProducer{} @@ -45,10 +47,10 @@ func BuiltinFactory(conf *DatabaseConfig) (DatabaseType, error) { usernameLen: 16, } - return &MySQL{ + dbType = &MySQL{ ConnectionProducer: connProducer, CredentialsProducer: credsProducer, - }, nil + } case cassandraTypeName: connProducer := &cassandraConnectionProducer{} @@ -56,13 +58,22 @@ func BuiltinFactory(conf *DatabaseConfig) (DatabaseType, error) { credsProducer := &cassandraCredentialsProducer{} - return &Cassandra{ + dbType = &Cassandra{ ConnectionProducer: connProducer, CredentialsProducer: credsProducer, - }, nil + } + + default: + return nil, ErrUnsupportedDatabaseType + } + + // Wrap with metrics middleware + dbType = &databaseMetricsMiddleware{ + next: dbType, + typeStr: dbType.Type(), } - return nil, ErrUnsupportedDatabaseType + return dbType, nil } func PluginFactory(conf *DatabaseConfig) (DatabaseType, error) { @@ -89,7 +100,7 @@ type DatabaseType interface { RevokeUser(statements Statements, username string) error Initialize(map[string]interface{}) error - Close() + Close() error CredentialsProducer } diff --git a/builtin/logical/database/dbs/metricsmiddleware.go b/builtin/logical/database/dbs/metricsmiddleware.go new file mode 100644 index 000000000000..61b4bd4ebca9 --- /dev/null +++ b/builtin/logical/database/dbs/metricsmiddleware.go @@ -0,0 +1,145 @@ +package dbs + +import ( + "time" + + metrics "github.com/armon/go-metrics" +) + +type databaseMetricsMiddleware struct { + next DatabaseType + + typeStr string +} + +func (mw *databaseMetricsMiddleware) Type() string { + return mw.next.Type() +} + +func (mw *databaseMetricsMiddleware) CreateUser(statements Statements, username, password, expiration string) (err error) { + defer func(now time.Time) { + metrics.MeasureSince([]string{"database", "CreateUser"}, now) + metrics.MeasureSince([]string{"database", mw.typeStr, "CreateUser"}, now) + + if err != nil { + metrics.IncrCounter([]string{"database", "CreateUser", "error"}, 1) + metrics.IncrCounter([]string{"database", mw.typeStr, "CreateUser", "error"}, 1) + } + }(time.Now()) + + metrics.IncrCounter([]string{"database", "CreateUser"}, 1) + metrics.IncrCounter([]string{"database", mw.typeStr, "CreateUser"}, 1) + return mw.next.CreateUser(statements, username, password, expiration) +} + +func (mw *databaseMetricsMiddleware) RenewUser(statements Statements, username, expiration string) (err error) { + defer func(now time.Time) { + metrics.MeasureSince([]string{"database", "RenewUser"}, now) + metrics.MeasureSince([]string{"database", mw.typeStr, "RenewUser"}, now) + + if err != nil { + metrics.IncrCounter([]string{"database", "RenewUser", "error"}, 1) + metrics.IncrCounter([]string{"database", mw.typeStr, "RenewUser", "error"}, 1) + } + }(time.Now()) + + metrics.IncrCounter([]string{"database", "RenewUser"}, 1) + metrics.IncrCounter([]string{"database", mw.typeStr, "RenewUser"}, 1) + return mw.next.RenewUser(statements, username, expiration) +} + +func (mw *databaseMetricsMiddleware) RevokeUser(statements Statements, username string) (err error) { + defer func(now time.Time) { + metrics.MeasureSince([]string{"database", "RevokeUser"}, now) + metrics.MeasureSince([]string{"database", mw.typeStr, "RevokeUser"}, now) + + if err != nil { + metrics.IncrCounter([]string{"database", "RevokeUser", "error"}, 1) + metrics.IncrCounter([]string{"database", mw.typeStr, "RevokeUser", "error"}, 1) + } + }(time.Now()) + + metrics.IncrCounter([]string{"database", "RevokeUser"}, 1) + metrics.IncrCounter([]string{"database", mw.typeStr, "RevokeUser"}, 1) + return mw.next.RevokeUser(statements, username) +} + +func (mw *databaseMetricsMiddleware) Initialize(conf map[string]interface{}) (err error) { + defer func(now time.Time) { + metrics.MeasureSince([]string{"database", "Initialize"}, now) + metrics.MeasureSince([]string{"database", mw.typeStr, "Initialize"}, now) + + if err != nil { + metrics.IncrCounter([]string{"database", "Initialize", "error"}, 1) + metrics.IncrCounter([]string{"database", mw.typeStr, "Initialize", "error"}, 1) + } + }(time.Now()) + + metrics.IncrCounter([]string{"database", "Initialize"}, 1) + metrics.IncrCounter([]string{"database", mw.typeStr, "Initialize"}, 1) + return mw.next.Initialize(conf) +} + +func (mw *databaseMetricsMiddleware) Close() (err error) { + defer func(now time.Time) { + metrics.MeasureSince([]string{"database", "Close"}, now) + metrics.MeasureSince([]string{"database", mw.typeStr, "Close"}, now) + + if err != nil { + metrics.IncrCounter([]string{"database", "Close", "error"}, 1) + metrics.IncrCounter([]string{"database", mw.typeStr, "Close", "error"}, 1) + } + }(time.Now()) + + metrics.IncrCounter([]string{"database", "Close"}, 1) + metrics.IncrCounter([]string{"database", mw.typeStr, "Close"}, 1) + return mw.next.Close() +} + +func (mw *databaseMetricsMiddleware) GenerateUsername(displayName string) (_ string, err error) { + defer func(now time.Time) { + metrics.MeasureSince([]string{"database", "GenerateUsername"}, now) + metrics.MeasureSince([]string{"database", mw.typeStr, "GenerateUsername"}, now) + + if err != nil { + metrics.IncrCounter([]string{"database", "GenerateUsername", "error"}, 1) + metrics.IncrCounter([]string{"database", mw.typeStr, "GenerateUsername", "error"}, 1) + } + }(time.Now()) + + metrics.IncrCounter([]string{"database", "GenerateUsername"}, 1) + metrics.IncrCounter([]string{"database", mw.typeStr, "GenerateUsername"}, 1) + return mw.next.GenerateUsername(displayName) +} + +func (mw *databaseMetricsMiddleware) GeneratePassword() (_ string, err error) { + defer func(now time.Time) { + metrics.MeasureSince([]string{"database", "GeneratePassword"}, now) + metrics.MeasureSince([]string{"database", mw.typeStr, "GeneratePassword"}, now) + + if err != nil { + metrics.IncrCounter([]string{"database", "GeneratePassword", "error"}, 1) + metrics.IncrCounter([]string{"database", mw.typeStr, "GeneratePassword", "error"}, 1) + } + }(time.Now()) + + metrics.IncrCounter([]string{"database", "GeneratePassword"}, 1) + metrics.IncrCounter([]string{"database", mw.typeStr, "GeneratePassword"}, 1) + return mw.next.GeneratePassword() +} + +func (mw *databaseMetricsMiddleware) GenerateExpiration(duration time.Duration) (_ string, err error) { + defer func(now time.Time) { + metrics.MeasureSince([]string{"database", "GenerateExpiration"}, now) + metrics.MeasureSince([]string{"database", mw.typeStr, "GenerateExpiration"}, now) + + if err != nil { + metrics.IncrCounter([]string{"database", "GenerateExpiration", "error"}, 1) + metrics.IncrCounter([]string{"database", mw.typeStr, "GenerateExpiration", "error"}, 1) + } + }(time.Now()) + + metrics.IncrCounter([]string{"database", "GenerateExpiration"}, 1) + metrics.IncrCounter([]string{"database", mw.typeStr, "GenerateExpiration"}, 1) + return mw.next.GenerateExpiration(duration) +} diff --git a/builtin/logical/database/dbs/plugin.go b/builtin/logical/database/dbs/plugin.go index b244a33fc667..7b2b18e009f7 100644 --- a/builtin/logical/database/dbs/plugin.go +++ b/builtin/logical/database/dbs/plugin.go @@ -38,10 +38,11 @@ type DatabasePluginClient struct { *databasePluginRPCClient } -func (dc *DatabasePluginClient) Close() { - dc.databasePluginRPCClient.Close() - +func (dc *DatabasePluginClient) Close() error { + err := dc.databasePluginRPCClient.Close() dc.client.Kill() + + return err } func newPluginClient(command, checksum string) (DatabaseType, error) { @@ -179,7 +180,7 @@ type databasePluginRPCServer struct { } func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error { - *resp = "string" + *resp = ds.impl.Type() return nil } From a6ae4bd3564bc147c7a08a248bfb3ebf72adcb83 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 14 Mar 2017 13:12:47 -0700 Subject: [PATCH 024/152] wrap plugin database type with metrics middleware --- builtin/logical/database/dbs/db.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go index 2cc42a731f84..3b10db46461a 100644 --- a/builtin/logical/database/dbs/db.go +++ b/builtin/logical/database/dbs/db.go @@ -90,6 +90,12 @@ func PluginFactory(conf *DatabaseConfig) (DatabaseType, error) { return nil, err } + // Wrap with metrics middleware + db = &databaseMetricsMiddleware{ + next: db, + typeStr: db.Type(), + } + return db, nil } From 5b05f62fa314368ce3ee85a118d37af7100766d6 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 15 Mar 2017 17:14:48 -0700 Subject: [PATCH 025/152] Work on TLS communication over plugins --- builtin/logical/database/dbs/db.go | 10 +- builtin/logical/database/dbs/plugin.go | 269 +++++++++++++++++- .../database/path_config_connection.go | 4 +- logical/system_view.go | 9 + vault/dynamic_system_view.go | 27 ++ 5 files changed, 311 insertions(+), 8 deletions(-) diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go index 3b10db46461a..b681de360bc4 100644 --- a/builtin/logical/database/dbs/db.go +++ b/builtin/logical/database/dbs/db.go @@ -5,6 +5,8 @@ import ( "fmt" "strings" "time" + + "github.com/hashicorp/vault/logical" ) const ( @@ -18,9 +20,9 @@ var ( ErrUnsupportedDatabaseType = errors.New("Unsupported database type") ) -type Factory func(*DatabaseConfig) (DatabaseType, error) +type Factory func(*DatabaseConfig, logical.SystemView) (DatabaseType, error) -func BuiltinFactory(conf *DatabaseConfig) (DatabaseType, error) { +func BuiltinFactory(conf *DatabaseConfig, sys logical.SystemView) (DatabaseType, error) { var dbType DatabaseType switch conf.DatabaseType { @@ -76,7 +78,7 @@ func BuiltinFactory(conf *DatabaseConfig) (DatabaseType, error) { return dbType, nil } -func PluginFactory(conf *DatabaseConfig) (DatabaseType, error) { +func PluginFactory(conf *DatabaseConfig, sys logical.SystemView) (DatabaseType, error) { if conf.PluginCommand == "" { return nil, errors.New("ERROR") } @@ -85,7 +87,7 @@ func PluginFactory(conf *DatabaseConfig) (DatabaseType, error) { return nil, errors.New("ERROR") } - db, err := newPluginClient(conf.PluginCommand, conf.PluginChecksum) + db, err := newPluginClient(sys, conf.PluginCommand, conf.PluginChecksum) if err != nil { return nil, err } diff --git a/builtin/logical/database/dbs/plugin.go b/builtin/logical/database/dbs/plugin.go index 7b2b18e009f7..e4f5359a744b 100644 --- a/builtin/logical/database/dbs/plugin.go +++ b/builtin/logical/database/dbs/plugin.go @@ -1,12 +1,31 @@ package dbs import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/base64" + "errors" + "fmt" + "math/big" + mathrand "math/rand" "net/rpc" + "net/url" + "os" "os/exec" + "strings" "sync" "time" + "github.com/SermoDigital/jose/jws" + "github.com/hashicorp/errwrap" "github.com/hashicorp/go-plugin" + uuid "github.com/hashicorp/go-uuid" + "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/logical" ) // handshakeConfigs are used to just do a basic handshake between @@ -45,16 +64,155 @@ func (dc *DatabasePluginClient) Close() error { return err } -func newPluginClient(command, checksum string) (DatabaseType, error) { +func generateX509Cert() ([]byte, *x509.Certificate, *ecdsa.PrivateKey, error) { + key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + if err != nil { + // c.logger.Error("core: failed to generate replicated cluster signing key", "error", err) + return nil, nil, nil, err + } + + //c.logger.Trace("core: generating replicated cluster certificate") + + host, err := uuid.GenerateUUID() + if err != nil { + return nil, nil, nil, err + } + host = "localhost" + template := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: host, + }, + DNSNames: []string{host}, + ExtKeyUsage: []x509.ExtKeyUsage{ + x509.ExtKeyUsageServerAuth, + x509.ExtKeyUsageClientAuth, + }, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement | x509.KeyUsageCertSign, + SerialNumber: big.NewInt(mathrand.Int63()), + NotBefore: time.Now().Add(-30 * time.Second), + // 30 years of single-active uptime ought to be enough for anybody + NotAfter: time.Now().Add(262980 * time.Hour), + BasicConstraintsValid: true, + IsCA: true, + } + + certBytes, err := x509.CreateCertificate(rand.Reader, template, template, key.Public(), key) + if err != nil { + // c.logger.Error("core: error generating self-signed cert for replication", "error", err) + return nil, nil, nil, fmt.Errorf("unable to generate replicated cluster certificate: %v", err) + } + + caCert, err := x509.ParseCertificate(certBytes) + if err != nil { + // c.logger.Error("core: error parsing replicated self-signed cert", "error", err) + return nil, nil, nil, fmt.Errorf("error parsing generated replication certificate: %v", err) + } + + return certBytes, caCert, key, nil +} + +func generateClientCert(CACert *x509.Certificate, CAKey *ecdsa.PrivateKey) ([]byte, *x509.Certificate, []byte, error) { + host, err := uuid.GenerateUUID() + if err != nil { + return nil, nil, nil, err + } + host = "localhost" + template := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: host, + }, + DNSNames: []string{host}, + ExtKeyUsage: []x509.ExtKeyUsage{ + x509.ExtKeyUsageClientAuth, + x509.ExtKeyUsageServerAuth, + }, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement, + SerialNumber: big.NewInt(mathrand.Int63()), + NotBefore: time.Now().Add(-30 * time.Second), + NotAfter: time.Now().Add(262980 * time.Hour), + } + + clientKey, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + if err != nil { + return nil, nil, nil, errwrap.Wrapf("error generating client key: {{err}}", err) + } + + certBytes, err := x509.CreateCertificate(rand.Reader, template, CACert, clientKey.Public(), CAKey) + if err != nil { + return nil, nil, nil, errwrap.Wrapf("unable to generate client certificate: {{err}}", err) + } + + clientCert, err := x509.ParseCertificate(certBytes) + if err != nil { + // c.logger.Error("core: error parsing replicated self-signed cert", "error", err) + return nil, nil, nil, fmt.Errorf("error parsing generated replication certificate: %v", err) + } + + keyBytes, err := x509.MarshalECPrivateKey(clientKey) + if err != nil { + return nil, nil, nil, err + } + + return certBytes, clientCert, keyBytes, nil +} + +func newPluginClient(sys logical.SystemView, command, checksum string) (DatabaseType, error) { // pluginMap is the map of plugins we can dispense. var pluginMap = map[string]plugin.Plugin{ "database": new(DatabasePlugin), } + CACertBytes, CACert, CAKey, err := generateX509Cert() + if err != nil { + return nil, err + } + + clientCertBytes, clientCert, clientKey, err := generateClientCert(CACert, CAKey) + if err != nil { + return nil, err + } + + /* serverCert, serverKey, err := generateClientCert(CACert, CAKey) + if err != nil { + return nil, err + }*/ + serverKey, err := x509.MarshalECPrivateKey(CAKey) + if err != nil { + return nil, err + } + cert := tls.Certificate{ + Certificate: [][]byte{clientCertBytes, CACertBytes}, + PrivateKey: clientKey, + Leaf: clientCert, + } + + clientCertPool := x509.NewCertPool() + clientCertPool.AddCert(CACert) + + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{cert}, + RootCAs: clientCertPool, + ClientCAs: clientCertPool, + ServerName: CACert.Subject.CommonName, + MinVersion: tls.VersionTLS12, + } + + tlsConfig.BuildNameToCertificate() + + wrapToken, err := sys.ResponseWrapData(map[string]interface{}{ + "CACert": CACertBytes, + "ServerCert": CACertBytes, + "ServerKey": serverKey, + }, time.Second*10, true) + + cmd := exec.Command(command) + cmd.Env = append(cmd.Env, fmt.Sprintf("VAULT_WRAP_TOKEN=%s", wrapToken)) + client := plugin.NewClient(&plugin.ClientConfig{ HandshakeConfig: handshakeConfig, Plugins: pluginMap, - Cmd: exec.Command(command), + Cmd: cmd, + TLSConfig: tlsConfig, }) // Connect via RPC @@ -92,9 +250,116 @@ func NewPluginServer(db DatabaseType) { plugin.Serve(&plugin.ServeConfig{ HandshakeConfig: handshakeConfig, Plugins: pluginMap, + TLSProvider: VaultPluginTLSProvider, }) } +func VaultPluginTLSProvider() (*tls.Config, error) { + unwrapToken := os.Getenv("VAULT_WRAP_TOKEN") + if strings.Count(unwrapToken, ".") != 2 { + return nil, errors.New("Could not parse unwraptoken") + } + + wt, err := jws.ParseJWT([]byte(unwrapToken)) + if err != nil { + return nil, errors.New(fmt.Sprintf("error decoding token: %s", err)) + } + if wt == nil { + return nil, errors.New("nil decoded token") + } + + addrRaw := wt.Claims().Get("addr") + if addrRaw == nil { + return nil, errors.New("decoded token does not contain primary cluster address") + } + vaultAddr, ok := addrRaw.(string) + if !ok { + return nil, errors.New("decoded token's address not valid") + } + if vaultAddr == "" { + return nil, errors.New(`no address for the vault found`) + } + + // Sanity check the value + if _, err := url.Parse(vaultAddr); err != nil { + return nil, errors.New(fmt.Sprintf("error parsing the vault address: %s", err)) + } + + clientConf := api.DefaultConfig() + clientConf.Address = vaultAddr + client, err := api.NewClient(clientConf) + if err != nil { + return nil, errwrap.Wrapf("error during token unwrap request: {{err}}", err) + } + + secret, err := client.Logical().Unwrap(unwrapToken) + if err != nil { + return nil, errwrap.Wrapf("error during token unwrap request: {{err}}", err) + } + + CABytesRaw, ok := secret.Data["CACert"].(string) + if !ok { + return nil, errors.New("error unmarshalling certificate") + } + + CABytes, err := base64.StdEncoding.DecodeString(CABytesRaw) + if err != nil { + return nil, fmt.Errorf("error parsing certificate: %v", err) + } + + CACert, err := x509.ParseCertificate(CABytes) + if err != nil { + return nil, fmt.Errorf("error parsing certificate: %v", err) + } + + serverCertBytesRaw, ok := secret.Data["ServerCert"].(string) + if !ok { + return nil, errors.New("error unmarshalling certificate") + } + + serverCertBytes, err := base64.StdEncoding.DecodeString(serverCertBytesRaw) + if err != nil { + return nil, fmt.Errorf("error parsing certificate: %v", err) + } + + serverCert, err := x509.ParseCertificate(serverCertBytes) + if err != nil { + return nil, fmt.Errorf("error parsing certificate: %v", err) + } + + serverKeyRaw, ok := secret.Data["ServerKey"].(string) + if !ok { + return nil, errors.New("error unmarshalling certificate") + } + + serverKey, err := base64.StdEncoding.DecodeString(serverKeyRaw) + if err != nil { + return nil, fmt.Errorf("error parsing certificate: %v", err) + } + + caCertPool := x509.NewCertPool() + caCertPool.AddCert(CACert) + + cert := tls.Certificate{ + Certificate: [][]byte{serverCertBytes}, + PrivateKey: serverKey, + Leaf: serverCert, + } + + // Setup TLS config + tlsConfig := &tls.Config{ + ClientCAs: caCertPool, + RootCAs: caCertPool, + ClientAuth: tls.RequireAndVerifyClientCert, + // TLS 1.2 minimum + MinVersion: tls.VersionTLS12, + Certificates: []tls.Certificate{cert}, + } + tlsConfig.BuildNameToCertificate() + + return tlsConfig, nil +} + // ---- RPC client domain ---- type databasePluginRPCClient struct { diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 6c0a63a11f26..0a99ad196263 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -63,7 +63,7 @@ func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framew factory := config.GetFactory() - db, err = factory(&config) + db, err = factory(&config, b.System()) if err != nil { return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil } @@ -231,7 +231,7 @@ func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework. return logical.ErrorResponse("Can not change type of existing connection."), nil } } else { - db, err = factory(config) + db, err = factory(config, b.System()) if err != nil { return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil } diff --git a/logical/system_view.go b/logical/system_view.go index d769397dfcc9..56254b33a17a 100644 --- a/logical/system_view.go +++ b/logical/system_view.go @@ -1,6 +1,7 @@ package logical import ( + "errors" "time" "github.com/hashicorp/vault/helper/consts" @@ -37,6 +38,10 @@ type SystemView interface { // ReplicationState indicates the state of cluster replication ReplicationState() consts.ReplicationState + + // ResponseWrapData wraps the given data in a cubbyhole and returns the + // token used to unwrap. + ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (string, error) } type StaticSystemView struct { @@ -72,3 +77,7 @@ func (d StaticSystemView) CachingDisabled() bool { func (d StaticSystemView) ReplicationState() consts.ReplicationState { return d.ReplicationStateVal } + +func (d StaticSystemView) ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (string, error) { + return "", errors.New("ResponseWrapData is not implimented in StaticSystemView") +} diff --git a/vault/dynamic_system_view.go b/vault/dynamic_system_view.go index 32c906fae602..4c6807ace930 100644 --- a/vault/dynamic_system_view.go +++ b/vault/dynamic_system_view.go @@ -87,3 +87,30 @@ func (d dynamicSystemView) ReplicationState() consts.ReplicationState { d.core.clusterParamsLock.RUnlock() return state } + +// ResponseWrapData wraps the given data in a cubbyhole and returns the +// token used to unwrap. +func (d dynamicSystemView) ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (string, error) { + req := &logical.Request{ + Operation: logical.CreateOperation, + Path: "sys/init", + } + + resp := &logical.Response{ + WrapInfo: &logical.ResponseWrapInfo{ + TTL: ttl, + }, + Data: data, + } + + if jwt { + resp.WrapInfo.Format = "jwt" + } + + _, err := d.core.wrapInCubbyhole(req, resp) + if err != nil { + return "", err + } + + return resp.WrapInfo.Token, nil +} From 3890f194a48ce1d3bc7c9fc9cec9f89834fab73b Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 16 Mar 2017 11:55:21 -0700 Subject: [PATCH 026/152] Break tls code into helper library --- builtin/logical/database/dbs/plugin.go | 220 +------------------------ helper/pluginutil/tls.go | 218 ++++++++++++++++++++++++ 2 files changed, 222 insertions(+), 216 deletions(-) create mode 100644 helper/pluginutil/tls.go diff --git a/builtin/logical/database/dbs/plugin.go b/builtin/logical/database/dbs/plugin.go index e4f5359a744b..b4649fc7f419 100644 --- a/builtin/logical/database/dbs/plugin.go +++ b/builtin/logical/database/dbs/plugin.go @@ -1,30 +1,16 @@ package dbs import ( - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rand" "crypto/tls" "crypto/x509" - "crypto/x509/pkix" - "encoding/base64" - "errors" "fmt" - "math/big" - mathrand "math/rand" "net/rpc" - "net/url" - "os" "os/exec" - "strings" "sync" "time" - "github.com/SermoDigital/jose/jws" - "github.com/hashicorp/errwrap" "github.com/hashicorp/go-plugin" - uuid "github.com/hashicorp/go-uuid" - "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/logical" ) @@ -64,110 +50,18 @@ func (dc *DatabasePluginClient) Close() error { return err } -func generateX509Cert() ([]byte, *x509.Certificate, *ecdsa.PrivateKey, error) { - key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) - if err != nil { - // c.logger.Error("core: failed to generate replicated cluster signing key", "error", err) - return nil, nil, nil, err - } - - //c.logger.Trace("core: generating replicated cluster certificate") - - host, err := uuid.GenerateUUID() - if err != nil { - return nil, nil, nil, err - } - host = "localhost" - template := &x509.Certificate{ - Subject: pkix.Name{ - CommonName: host, - }, - DNSNames: []string{host}, - ExtKeyUsage: []x509.ExtKeyUsage{ - x509.ExtKeyUsageServerAuth, - x509.ExtKeyUsageClientAuth, - }, - KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement | x509.KeyUsageCertSign, - SerialNumber: big.NewInt(mathrand.Int63()), - NotBefore: time.Now().Add(-30 * time.Second), - // 30 years of single-active uptime ought to be enough for anybody - NotAfter: time.Now().Add(262980 * time.Hour), - BasicConstraintsValid: true, - IsCA: true, - } - - certBytes, err := x509.CreateCertificate(rand.Reader, template, template, key.Public(), key) - if err != nil { - // c.logger.Error("core: error generating self-signed cert for replication", "error", err) - return nil, nil, nil, fmt.Errorf("unable to generate replicated cluster certificate: %v", err) - } - - caCert, err := x509.ParseCertificate(certBytes) - if err != nil { - // c.logger.Error("core: error parsing replicated self-signed cert", "error", err) - return nil, nil, nil, fmt.Errorf("error parsing generated replication certificate: %v", err) - } - - return certBytes, caCert, key, nil -} - -func generateClientCert(CACert *x509.Certificate, CAKey *ecdsa.PrivateKey) ([]byte, *x509.Certificate, []byte, error) { - host, err := uuid.GenerateUUID() - if err != nil { - return nil, nil, nil, err - } - host = "localhost" - template := &x509.Certificate{ - Subject: pkix.Name{ - CommonName: host, - }, - DNSNames: []string{host}, - ExtKeyUsage: []x509.ExtKeyUsage{ - x509.ExtKeyUsageClientAuth, - x509.ExtKeyUsageServerAuth, - }, - KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement, - SerialNumber: big.NewInt(mathrand.Int63()), - NotBefore: time.Now().Add(-30 * time.Second), - NotAfter: time.Now().Add(262980 * time.Hour), - } - - clientKey, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) - if err != nil { - return nil, nil, nil, errwrap.Wrapf("error generating client key: {{err}}", err) - } - - certBytes, err := x509.CreateCertificate(rand.Reader, template, CACert, clientKey.Public(), CAKey) - if err != nil { - return nil, nil, nil, errwrap.Wrapf("unable to generate client certificate: {{err}}", err) - } - - clientCert, err := x509.ParseCertificate(certBytes) - if err != nil { - // c.logger.Error("core: error parsing replicated self-signed cert", "error", err) - return nil, nil, nil, fmt.Errorf("error parsing generated replication certificate: %v", err) - } - - keyBytes, err := x509.MarshalECPrivateKey(clientKey) - if err != nil { - return nil, nil, nil, err - } - - return certBytes, clientCert, keyBytes, nil -} - func newPluginClient(sys logical.SystemView, command, checksum string) (DatabaseType, error) { // pluginMap is the map of plugins we can dispense. var pluginMap = map[string]plugin.Plugin{ "database": new(DatabasePlugin), } - CACertBytes, CACert, CAKey, err := generateX509Cert() + CACertBytes, CACert, CAKey, err := pluginutil.GenerateX509Cert() if err != nil { return nil, err } - clientCertBytes, clientCert, clientKey, err := generateClientCert(CACert, CAKey) + clientCertBytes, clientCert, clientKey, err := pluginutil.GenerateClientCert(CACert, CAKey) if err != nil { return nil, err } @@ -250,116 +144,10 @@ func NewPluginServer(db DatabaseType) { plugin.Serve(&plugin.ServeConfig{ HandshakeConfig: handshakeConfig, Plugins: pluginMap, - TLSProvider: VaultPluginTLSProvider, + TLSProvider: pluginutil.VaultPluginTLSProvider, }) } -func VaultPluginTLSProvider() (*tls.Config, error) { - unwrapToken := os.Getenv("VAULT_WRAP_TOKEN") - if strings.Count(unwrapToken, ".") != 2 { - return nil, errors.New("Could not parse unwraptoken") - } - - wt, err := jws.ParseJWT([]byte(unwrapToken)) - if err != nil { - return nil, errors.New(fmt.Sprintf("error decoding token: %s", err)) - } - if wt == nil { - return nil, errors.New("nil decoded token") - } - - addrRaw := wt.Claims().Get("addr") - if addrRaw == nil { - return nil, errors.New("decoded token does not contain primary cluster address") - } - vaultAddr, ok := addrRaw.(string) - if !ok { - return nil, errors.New("decoded token's address not valid") - } - if vaultAddr == "" { - return nil, errors.New(`no address for the vault found`) - } - - // Sanity check the value - if _, err := url.Parse(vaultAddr); err != nil { - return nil, errors.New(fmt.Sprintf("error parsing the vault address: %s", err)) - } - - clientConf := api.DefaultConfig() - clientConf.Address = vaultAddr - client, err := api.NewClient(clientConf) - if err != nil { - return nil, errwrap.Wrapf("error during token unwrap request: {{err}}", err) - } - - secret, err := client.Logical().Unwrap(unwrapToken) - if err != nil { - return nil, errwrap.Wrapf("error during token unwrap request: {{err}}", err) - } - - CABytesRaw, ok := secret.Data["CACert"].(string) - if !ok { - return nil, errors.New("error unmarshalling certificate") - } - - CABytes, err := base64.StdEncoding.DecodeString(CABytesRaw) - if err != nil { - return nil, fmt.Errorf("error parsing certificate: %v", err) - } - - CACert, err := x509.ParseCertificate(CABytes) - if err != nil { - return nil, fmt.Errorf("error parsing certificate: %v", err) - } - - serverCertBytesRaw, ok := secret.Data["ServerCert"].(string) - if !ok { - return nil, errors.New("error unmarshalling certificate") - } - - serverCertBytes, err := base64.StdEncoding.DecodeString(serverCertBytesRaw) - if err != nil { - return nil, fmt.Errorf("error parsing certificate: %v", err) - } - - serverCert, err := x509.ParseCertificate(serverCertBytes) - if err != nil { - return nil, fmt.Errorf("error parsing certificate: %v", err) - } - - serverKeyRaw, ok := secret.Data["ServerKey"].(string) - if !ok { - return nil, errors.New("error unmarshalling certificate") - } - - serverKey, err := base64.StdEncoding.DecodeString(serverKeyRaw) - if err != nil { - return nil, fmt.Errorf("error parsing certificate: %v", err) - } - - caCertPool := x509.NewCertPool() - caCertPool.AddCert(CACert) - - cert := tls.Certificate{ - Certificate: [][]byte{serverCertBytes}, - PrivateKey: serverKey, - Leaf: serverCert, - } - - // Setup TLS config - tlsConfig := &tls.Config{ - ClientCAs: caCertPool, - RootCAs: caCertPool, - ClientAuth: tls.RequireAndVerifyClientCert, - // TLS 1.2 minimum - MinVersion: tls.VersionTLS12, - Certificates: []tls.Certificate{cert}, - } - tlsConfig.BuildNameToCertificate() - - return tlsConfig, nil -} - // ---- RPC client domain ---- type databasePluginRPCClient struct { diff --git a/helper/pluginutil/tls.go b/helper/pluginutil/tls.go new file mode 100644 index 000000000000..55f27a3881cc --- /dev/null +++ b/helper/pluginutil/tls.go @@ -0,0 +1,218 @@ +package pluginutil + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/base64" + "errors" + "fmt" + "math/big" + mathrand "math/rand" + "net/url" + "os" + "strings" + "time" + + "github.com/SermoDigital/jose/jws" + "github.com/hashicorp/errwrap" + uuid "github.com/hashicorp/go-uuid" + "github.com/hashicorp/vault/api" +) + +func GenerateX509Cert() ([]byte, *x509.Certificate, *ecdsa.PrivateKey, error) { + key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + if err != nil { + return nil, nil, nil, err + } + + host, err := uuid.GenerateUUID() + if err != nil { + return nil, nil, nil, err + } + host = "localhost" + template := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: host, + }, + DNSNames: []string{host}, + ExtKeyUsage: []x509.ExtKeyUsage{ + x509.ExtKeyUsageServerAuth, + x509.ExtKeyUsageClientAuth, + }, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement | x509.KeyUsageCertSign, + SerialNumber: big.NewInt(mathrand.Int63()), + NotBefore: time.Now().Add(-30 * time.Second), + // 30 years of single-active uptime ought to be enough for anybody + NotAfter: time.Now().Add(262980 * time.Hour), + BasicConstraintsValid: true, + IsCA: true, + } + + certBytes, err := x509.CreateCertificate(rand.Reader, template, template, key.Public(), key) + if err != nil { + return nil, nil, nil, fmt.Errorf("unable to generate replicated cluster certificate: %v", err) + } + + caCert, err := x509.ParseCertificate(certBytes) + if err != nil { + return nil, nil, nil, fmt.Errorf("error parsing generated replication certificate: %v", err) + } + + return certBytes, caCert, key, nil +} + +func GenerateClientCert(CACert *x509.Certificate, CAKey *ecdsa.PrivateKey) ([]byte, *x509.Certificate, []byte, error) { + host, err := uuid.GenerateUUID() + if err != nil { + return nil, nil, nil, err + } + host = "localhost" + template := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: host, + }, + DNSNames: []string{host}, + ExtKeyUsage: []x509.ExtKeyUsage{ + x509.ExtKeyUsageClientAuth, + x509.ExtKeyUsageServerAuth, + }, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement, + SerialNumber: big.NewInt(mathrand.Int63()), + NotBefore: time.Now().Add(-30 * time.Second), + NotAfter: time.Now().Add(262980 * time.Hour), + } + + clientKey, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + if err != nil { + return nil, nil, nil, errwrap.Wrapf("error generating client key: {{err}}", err) + } + + certBytes, err := x509.CreateCertificate(rand.Reader, template, CACert, clientKey.Public(), CAKey) + if err != nil { + return nil, nil, nil, errwrap.Wrapf("unable to generate client certificate: {{err}}", err) + } + + clientCert, err := x509.ParseCertificate(certBytes) + if err != nil { + return nil, nil, nil, fmt.Errorf("error parsing generated replication certificate: %v", err) + } + + keyBytes, err := x509.MarshalECPrivateKey(clientKey) + if err != nil { + return nil, nil, nil, err + } + + return certBytes, clientCert, keyBytes, nil +} + +// VaultPluginTLSProvider is run inside a plugin and retrives the response +// wrapped TLS certificate from vault. It returns a configured tlsConfig. +func VaultPluginTLSProvider() (*tls.Config, error) { + unwrapToken := os.Getenv("VAULT_WRAP_TOKEN") + if strings.Count(unwrapToken, ".") != 2 { + return nil, errors.New("Could not parse unwraptoken") + } + + wt, err := jws.ParseJWT([]byte(unwrapToken)) + if err != nil { + return nil, errors.New(fmt.Sprintf("error decoding token: %s", err)) + } + if wt == nil { + return nil, errors.New("nil decoded token") + } + + addrRaw := wt.Claims().Get("addr") + if addrRaw == nil { + return nil, errors.New("decoded token does not contain primary cluster address") + } + vaultAddr, ok := addrRaw.(string) + if !ok { + return nil, errors.New("decoded token's address not valid") + } + if vaultAddr == "" { + return nil, errors.New(`no address for the vault found`) + } + + // Sanity check the value + if _, err := url.Parse(vaultAddr); err != nil { + return nil, errors.New(fmt.Sprintf("error parsing the vault address: %s", err)) + } + + clientConf := api.DefaultConfig() + clientConf.Address = vaultAddr + client, err := api.NewClient(clientConf) + if err != nil { + return nil, errwrap.Wrapf("error during token unwrap request: {{err}}", err) + } + + secret, err := client.Logical().Unwrap(unwrapToken) + if err != nil { + return nil, errwrap.Wrapf("error during token unwrap request: {{err}}", err) + } + + CABytesRaw, ok := secret.Data["CACert"].(string) + if !ok { + return nil, errors.New("error unmarshalling certificate") + } + + CABytes, err := base64.StdEncoding.DecodeString(CABytesRaw) + if err != nil { + return nil, fmt.Errorf("error parsing certificate: %v", err) + } + + CACert, err := x509.ParseCertificate(CABytes) + if err != nil { + return nil, fmt.Errorf("error parsing certificate: %v", err) + } + + serverCertBytesRaw, ok := secret.Data["ServerCert"].(string) + if !ok { + return nil, errors.New("error unmarshalling certificate") + } + + serverCertBytes, err := base64.StdEncoding.DecodeString(serverCertBytesRaw) + if err != nil { + return nil, fmt.Errorf("error parsing certificate: %v", err) + } + + serverCert, err := x509.ParseCertificate(serverCertBytes) + if err != nil { + return nil, fmt.Errorf("error parsing certificate: %v", err) + } + + serverKeyRaw, ok := secret.Data["ServerKey"].(string) + if !ok { + return nil, errors.New("error unmarshalling certificate") + } + + serverKey, err := base64.StdEncoding.DecodeString(serverKeyRaw) + if err != nil { + return nil, fmt.Errorf("error parsing certificate: %v", err) + } + + caCertPool := x509.NewCertPool() + caCertPool.AddCert(CACert) + + cert := tls.Certificate{ + Certificate: [][]byte{serverCertBytes}, + PrivateKey: serverKey, + Leaf: serverCert, + } + + // Setup TLS config + tlsConfig := &tls.Config{ + ClientCAs: caCertPool, + RootCAs: caCertPool, + ClientAuth: tls.RequireAndVerifyClientCert, + // TLS 1.2 minimum + MinVersion: tls.VersionTLS12, + Certificates: []tls.Certificate{cert}, + } + tlsConfig.BuildNameToCertificate() + + return tlsConfig, nil +} From 2ef1cbf3a6b6e28f7e701c452443e02870d43015 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 16 Mar 2017 14:14:49 -0700 Subject: [PATCH 027/152] Comment and slight refactor of the TLS plugin helper --- builtin/logical/database/dbs/plugin.go | 45 +++---------- helper/pluginutil/tls.go | 89 +++++++++++++++++++++++--- 2 files changed, 89 insertions(+), 45 deletions(-) diff --git a/builtin/logical/database/dbs/plugin.go b/builtin/logical/database/dbs/plugin.go index b4649fc7f419..c068128d8b8b 100644 --- a/builtin/logical/database/dbs/plugin.go +++ b/builtin/logical/database/dbs/plugin.go @@ -1,8 +1,6 @@ package dbs import ( - "crypto/tls" - "crypto/x509" "fmt" "net/rpc" "os/exec" @@ -56,57 +54,34 @@ func newPluginClient(sys logical.SystemView, command, checksum string) (Database "database": new(DatabasePlugin), } - CACertBytes, CACert, CAKey, err := pluginutil.GenerateX509Cert() + // Get a CA TLS Certificate + CACertBytes, CACert, CAKey, err := pluginutil.GenerateCACert() if err != nil { return nil, err } - clientCertBytes, clientCert, clientKey, err := pluginutil.GenerateClientCert(CACert, CAKey) + // Use CA to sign a client cert and return a configured TLS config + clientTLSConfig, err := pluginutil.CreateClientTLSConfig(CACert, CAKey) if err != nil { return nil, err } - /* serverCert, serverKey, err := generateClientCert(CACert, CAKey) - if err != nil { - return nil, err - }*/ - serverKey, err := x509.MarshalECPrivateKey(CAKey) + // Use CA to sign a server cert and wrap the values in a response wrapped + // token. + wrapToken, err := pluginutil.WrapServerConfig(sys, CACertBytes, CACert, CAKey) if err != nil { return nil, err } - cert := tls.Certificate{ - Certificate: [][]byte{clientCertBytes, CACertBytes}, - PrivateKey: clientKey, - Leaf: clientCert, - } - - clientCertPool := x509.NewCertPool() - clientCertPool.AddCert(CACert) - - tlsConfig := &tls.Config{ - Certificates: []tls.Certificate{cert}, - RootCAs: clientCertPool, - ClientCAs: clientCertPool, - ServerName: CACert.Subject.CommonName, - MinVersion: tls.VersionTLS12, - } - - tlsConfig.BuildNameToCertificate() - - wrapToken, err := sys.ResponseWrapData(map[string]interface{}{ - "CACert": CACertBytes, - "ServerCert": CACertBytes, - "ServerKey": serverKey, - }, time.Second*10, true) + // Add the response wrap token to the ENV of the plugin cmd := exec.Command(command) - cmd.Env = append(cmd.Env, fmt.Sprintf("VAULT_WRAP_TOKEN=%s", wrapToken)) + cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", pluginutil.PluginUnwrapTokenEnv, wrapToken)) client := plugin.NewClient(&plugin.ClientConfig{ HandshakeConfig: handshakeConfig, Plugins: pluginMap, Cmd: cmd, - TLSConfig: tlsConfig, + TLSConfig: clientTLSConfig, }) // Connect via RPC diff --git a/helper/pluginutil/tls.go b/helper/pluginutil/tls.go index 55f27a3881cc..10ca8583a0cc 100644 --- a/helper/pluginutil/tls.go +++ b/helper/pluginutil/tls.go @@ -21,9 +21,16 @@ import ( "github.com/hashicorp/errwrap" uuid "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/logical" ) -func GenerateX509Cert() ([]byte, *x509.Certificate, *ecdsa.PrivateKey, error) { +var ( + PluginUnwrapTokenEnv = "VAULT_WRAP_TOKEN" +) + +// GenerateCACert returns a CA cert used to later sign the certificates for the +// plugin client and server. +func GenerateCACert() ([]byte, *x509.Certificate, *ecdsa.PrivateKey, error) { key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) if err != nil { return nil, nil, nil, err @@ -65,7 +72,9 @@ func GenerateX509Cert() ([]byte, *x509.Certificate, *ecdsa.PrivateKey, error) { return certBytes, caCert, key, nil } -func GenerateClientCert(CACert *x509.Certificate, CAKey *ecdsa.PrivateKey) ([]byte, *x509.Certificate, []byte, error) { +// generateSignedCert is used internally to create certificates for the plugin +// client and server. These certs are signed by the given CA Cert and Key. +func generateSignedCert(CACert *x509.Certificate, CAKey *ecdsa.PrivateKey) ([]byte, *x509.Certificate, *ecdsa.PrivateKey, error) { host, err := uuid.GenerateUUID() if err != nil { return nil, nil, nil, err @@ -101,22 +110,71 @@ func GenerateClientCert(CACert *x509.Certificate, CAKey *ecdsa.PrivateKey) ([]by return nil, nil, nil, fmt.Errorf("error parsing generated replication certificate: %v", err) } - keyBytes, err := x509.MarshalECPrivateKey(clientKey) + return certBytes, clientCert, clientKey, nil +} + +// CreateClientTLSConfig creates a signed certificate and returns a configured +// TLS config. +func CreateClientTLSConfig(CACert *x509.Certificate, CAKey *ecdsa.PrivateKey) (*tls.Config, error) { + clientCertBytes, clientCert, clientKey, err := generateSignedCert(CACert, CAKey) if err != nil { - return nil, nil, nil, err + return nil, err + } + + cert := tls.Certificate{ + Certificate: [][]byte{clientCertBytes}, + PrivateKey: clientKey, + Leaf: clientCert, } - return certBytes, clientCert, keyBytes, nil + clientCertPool := x509.NewCertPool() + clientCertPool.AddCert(CACert) + + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{cert}, + RootCAs: clientCertPool, + ClientCAs: clientCertPool, + ServerName: CACert.Subject.CommonName, + MinVersion: tls.VersionTLS12, + } + + tlsConfig.BuildNameToCertificate() + + return tlsConfig, nil +} + +// WrapServerConfig is used to create a server certificate and private key, then +// wrap them in an unwrap token for later retrieval by the plugin. +func WrapServerConfig(sys logical.SystemView, CACertBytes []byte, CACert *x509.Certificate, CAKey *ecdsa.PrivateKey) (string, error) { + serverCertBytes, _, serverKey, err := generateSignedCert(CACert, CAKey) + if err != nil { + return "", err + } + rawKey, err := x509.MarshalECPrivateKey(serverKey) + if err != nil { + return "", err + } + + wrapToken, err := sys.ResponseWrapData(map[string]interface{}{ + "CACert": CACertBytes, + "ServerCert": serverCertBytes, + "ServerKey": rawKey, + }, time.Second*10, true) + + return wrapToken, err } // VaultPluginTLSProvider is run inside a plugin and retrives the response -// wrapped TLS certificate from vault. It returns a configured tlsConfig. +// wrapped TLS certificate from vault. It returns a configured TLS Config. func VaultPluginTLSProvider() (*tls.Config, error) { - unwrapToken := os.Getenv("VAULT_WRAP_TOKEN") + unwrapToken := os.Getenv(PluginUnwrapTokenEnv) + + // Ensure unwrap token is a JWT if strings.Count(unwrapToken, ".") != 2 { return nil, errors.New("Could not parse unwraptoken") } + // Parse the JWT and retrieve the vault address wt, err := jws.ParseJWT([]byte(unwrapToken)) if err != nil { return nil, errors.New(fmt.Sprintf("error decoding token: %s", err)) @@ -142,6 +200,7 @@ func VaultPluginTLSProvider() (*tls.Config, error) { return nil, errors.New(fmt.Sprintf("error parsing the vault address: %s", err)) } + // Unwrap the token clientConf := api.DefaultConfig() clientConf.Address = vaultAddr client, err := api.NewClient(clientConf) @@ -154,9 +213,10 @@ func VaultPluginTLSProvider() (*tls.Config, error) { return nil, errwrap.Wrapf("error during token unwrap request: {{err}}", err) } + // Retrieve and parse the CA Certificate CABytesRaw, ok := secret.Data["CACert"].(string) if !ok { - return nil, errors.New("error unmarshalling certificate") + return nil, errors.New("error unmarshalling CA certificate") } CABytes, err := base64.StdEncoding.DecodeString(CABytesRaw) @@ -169,6 +229,7 @@ func VaultPluginTLSProvider() (*tls.Config, error) { return nil, fmt.Errorf("error parsing certificate: %v", err) } + // Retrieve and parse the server's certificate serverCertBytesRaw, ok := secret.Data["ServerCert"].(string) if !ok { return nil, errors.New("error unmarshalling certificate") @@ -184,19 +245,27 @@ func VaultPluginTLSProvider() (*tls.Config, error) { return nil, fmt.Errorf("error parsing certificate: %v", err) } - serverKeyRaw, ok := secret.Data["ServerKey"].(string) + // Retrieve and parse the server's private key + serverKeyB64, ok := secret.Data["ServerKey"].(string) if !ok { return nil, errors.New("error unmarshalling certificate") } - serverKey, err := base64.StdEncoding.DecodeString(serverKeyRaw) + serverKeyRaw, err := base64.StdEncoding.DecodeString(serverKeyB64) + if err != nil { + return nil, fmt.Errorf("error parsing certificate: %v", err) + } + + serverKey, err := x509.ParseECPrivateKey(serverKeyRaw) if err != nil { return nil, fmt.Errorf("error parsing certificate: %v", err) } + // Add CA cert to the cert pool caCertPool := x509.NewCertPool() caCertPool.AddCert(CACert) + // Build a certificate object out of the server's cert and private key. cert := tls.Certificate{ Certificate: [][]byte{serverCertBytes}, PrivateKey: serverKey, From a878791480b02084f32b52e9b2bb2e71916886e2 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 16 Mar 2017 14:17:44 -0700 Subject: [PATCH 028/152] Update the name of PluginUnwrapTokenEnv --- helper/pluginutil/tls.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/helper/pluginutil/tls.go b/helper/pluginutil/tls.go index 10ca8583a0cc..88d88689d1ae 100644 --- a/helper/pluginutil/tls.go +++ b/helper/pluginutil/tls.go @@ -25,7 +25,9 @@ import ( ) var ( - PluginUnwrapTokenEnv = "VAULT_WRAP_TOKEN" + // PluginUnwrapTokenEnv is the ENV name used to pass unwrap tokens to the + // plugin. + PluginUnwrapTokenEnv = "VAULT_UNWRAP_TOKEN" ) // GenerateCACert returns a CA cert used to later sign the certificates for the From 4043f533b8ed77e8e1090e6553abbabe6f6a28e7 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 16 Mar 2017 16:20:18 -0700 Subject: [PATCH 029/152] Add a secure config to verify the checksum of the plugin --- builtin/logical/database/dbs/plugin.go | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/builtin/logical/database/dbs/plugin.go b/builtin/logical/database/dbs/plugin.go index c068128d8b8b..96d28bf90e24 100644 --- a/builtin/logical/database/dbs/plugin.go +++ b/builtin/logical/database/dbs/plugin.go @@ -1,6 +1,8 @@ package dbs import ( + "crypto/sha256" + "encoding/hex" "fmt" "net/rpc" "os/exec" @@ -77,11 +79,22 @@ func newPluginClient(sys logical.SystemView, command, checksum string) (Database cmd := exec.Command(command) cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", pluginutil.PluginUnwrapTokenEnv, wrapToken)) + checksumDecoded, err := hex.DecodeString(checksum) + if err != nil { + return nil, err + } + + secureConfig := &plugin.SecureConfig{ + Checksum: checksumDecoded, + Hash: sha256.New(), + } + client := plugin.NewClient(&plugin.ClientConfig{ HandshakeConfig: handshakeConfig, Plugins: pluginMap, Cmd: cmd, TLSConfig: clientTLSConfig, + SecureConfig: secureConfig, }) // Connect via RPC From 404596e261710b7de75a1ee89ee139e47485fd12 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 16 Mar 2017 17:51:25 -0700 Subject: [PATCH 030/152] Change the handshake config from the default --- builtin/logical/database/dbs/plugin.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/builtin/logical/database/dbs/plugin.go b/builtin/logical/database/dbs/plugin.go index 96d28bf90e24..8fdcb81f03c5 100644 --- a/builtin/logical/database/dbs/plugin.go +++ b/builtin/logical/database/dbs/plugin.go @@ -20,8 +20,8 @@ import ( // directory. It is a UX feature, not a security feature. var handshakeConfig = plugin.HandshakeConfig{ ProtocolVersion: 1, - MagicCookieKey: "BASIC_PLUGIN", - MagicCookieValue: "hello", + MagicCookieKey: "VAULT_DATABASE_PLUGIN", + MagicCookieValue: "926a0820-aea2-be28-51d6-83cdf00e8edb", } type DatabasePlugin struct { From ff6749b198a93f5e41d921819d8e3cf016e41a99 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 16 Mar 2017 18:24:56 -0700 Subject: [PATCH 031/152] Comment and fix plugin Type function --- builtin/logical/database/dbs/plugin.go | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/builtin/logical/database/dbs/plugin.go b/builtin/logical/database/dbs/plugin.go index 8fdcb81f03c5..45c815a45d93 100644 --- a/builtin/logical/database/dbs/plugin.go +++ b/builtin/logical/database/dbs/plugin.go @@ -36,6 +36,8 @@ func (DatabasePlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, e return &databasePluginRPCClient{client: c}, nil } +// DatabasePluginClient embeds a databasePluginRPCClient and wraps it's close +// method to also call Close() on the plugin.Client. type DatabasePluginClient struct { client *plugin.Client sync.Mutex @@ -50,6 +52,9 @@ func (dc *DatabasePluginClient) Close() error { return err } +// newPluginClient returns a databaseRPCClient with a connection to a running +// plugin. The client is wrapped in a DatabasePluginClient object to ensure the +// plugin is killed on call of Close(). func newPluginClient(sys logical.SystemView, command, checksum string) (DatabaseType, error) { // pluginMap is the map of plugins we can dispense. var pluginMap = map[string]plugin.Plugin{ @@ -119,6 +124,9 @@ func newPluginClient(sys logical.SystemView, command, checksum string) (Database }, nil } +// NewPluginServer is called from within a plugin and wraps the provided +// DatabaseType implimentation in a databasePluginRPCServer object and starts a +// RPC server. func NewPluginServer(db DatabaseType) { dbPlugin := &DatabasePlugin{ impl: db, @@ -138,12 +146,18 @@ func NewPluginServer(db DatabaseType) { // ---- RPC client domain ---- +// databasePluginRPCClient impliments DatabaseType and is used on the client to +// make RPC calls to a plugin. type databasePluginRPCClient struct { client *rpc.Client } func (dr *databasePluginRPCClient) Type() string { - return "plugin" + var dbType string + //TODO: catch error + dr.client.Call("Plugin.Type", struct{}{}, &dbType) + + return fmt.Sprintf("plugin-%s", dbType) } func (dr *databasePluginRPCClient) CreateUser(statements Statements, username, password, expiration string) error { @@ -216,6 +230,8 @@ func (dr *databasePluginRPCClient) GenerateExpiration(duration time.Duration) (s } // ---- RPC server domain ---- + +// databasePluginRPCServer impliments DatabaseType and is run inside a plugin type databasePluginRPCServer struct { impl DatabaseType } From 2fdb3422a92353d4ead4940dc8eff5b263558d70 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 21 Mar 2017 16:05:59 -0700 Subject: [PATCH 032/152] Verify connections regardless of if this connections is already existing --- .../database/path_config_connection.go | 56 +++++++------------ 1 file changed, 21 insertions(+), 35 deletions(-) diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 0a99ad196263..dc8cf34e53c6 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -223,52 +223,38 @@ func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework. b.Lock() defer b.Unlock() - var db dbs.DatabaseType - if _, ok := b.connections[name]; ok { + db, err := factory(config, b.System()) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil + } - // Don't allow the connection type to change - if b.connections[name].Type() != connType { - return logical.ErrorResponse("Can not change type of existing connection."), nil - } - } else { - db, err = factory(config, b.System()) - if err != nil { + err = db.Initialize(config.ConnectionDetails) + if err != nil { + if !strings.Contains(err.Error(), "Error Initializing Connection") { return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil } - err := db.Initialize(config.ConnectionDetails) - if err != nil { - if !strings.Contains(err.Error(), "Error Initializing Connection") { - return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil - - } - - if verifyConnection { - return logical.ErrorResponse(err.Error()), nil - - } + if verifyConnection { + return logical.ErrorResponse(err.Error()), nil } - - b.connections[name] = db } - /* TODO: - // Don't check the connection_url if verification is disabled - verifyConnection := data.Get("verify_connection").(bool) - if verifyConnection { - // Verify the string - db, err := sql.Open("postgres", connURL) + if _, ok := b.connections[name]; ok { + // Don't update connection until the reset api is hit, close for + // now. + err = db.Close() if err != nil { - return logical.ErrorResponse(fmt.Sprintf( - "Error validating connection info: %s", err)), nil + return nil, err } - defer db.Close() - if err := db.Ping(); err != nil { - return logical.ErrorResponse(fmt.Sprintf( - "Error validating connection info: %s", err)), nil + + // Don't allow the connection type to change + if b.connections[name].Type() != connType { + return logical.ErrorResponse("Can not change type of existing connection."), nil } + } else { + // Save the new connection + b.connections[name] = db } - */ // Store it entry, err := logical.StorageEntryJSON(fmt.Sprintf("dbs/%s", name), config) From 2d6f36df17c4cb921de3ea83561ce9a0fc276830 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 21 Mar 2017 17:19:30 -0700 Subject: [PATCH 033/152] Add a delete method --- builtin/logical/database/backend.go | 2 +- .../database/path_config_connection.go | 44 +++++++++++++++++-- builtin/logical/database/path_role_create.go | 2 +- 3 files changed, 42 insertions(+), 6 deletions(-) diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index 69d91f6f278e..6108652532d8 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -27,7 +27,7 @@ func Backend(conf *logical.BackendConfig) *databaseBackend { }, Paths: []*framework.Path{ - pathConfigureConnection(&b), + pathConfigureBuiltinConnection(&b), pathConfigurePluginConnection(&b), pathListRoles(&b), pathRoles(&b), diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index dc8cf34e53c6..3bb3a5631033 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -78,15 +78,22 @@ func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framew return nil, nil } -func pathConfigureConnection(b *databaseBackend) *framework.Path { - return buildConfigConnectionPath("dbs/%s", b.connectionWriteHandler(dbs.BuiltinFactory), b.connectionReadHandler()) +// pathConfigureBuiltinConnection returns a configured framework.Path setup to +// operate on builtin databases. +func pathConfigureBuiltinConnection(b *databaseBackend) *framework.Path { + return buildConfigConnectionPath("dbs/%s", b.connectionWriteHandler(dbs.BuiltinFactory), b.connectionReadHandler(), b.connectionDeleteHandler()) } +// pathConfigurePluginConnection returns a configured framework.Path setup to +// operate on plugins. func pathConfigurePluginConnection(b *databaseBackend) *framework.Path { - return buildConfigConnectionPath("dbs/plugin/%s", b.connectionWriteHandler(dbs.PluginFactory), b.connectionReadHandler()) + return buildConfigConnectionPath("dbs/plugin/%s", b.connectionWriteHandler(dbs.PluginFactory), b.connectionReadHandler(), b.connectionDeleteHandler()) } -func buildConfigConnectionPath(path string, updateOp, readOp framework.OperationFunc) *framework.Path { +// buildConfigConnectionPath reutns a configured framework.Path using the passed +// in operation functions to complete the request. Used to distinguish calls +// between builtin and plugin databases. +func buildConfigConnectionPath(path string, updateOp, readOp, deleteOp framework.OperationFunc) *framework.Path { return &framework.Path{ Pattern: fmt.Sprintf(path, framework.GenericNameRegex("name")), Fields: map[string]*framework.FieldSchema{ @@ -145,6 +152,7 @@ reduced to the same size.`, Callbacks: map[logical.Operation]framework.OperationFunc{ logical.UpdateOperation: updateOp, logical.ReadOperation: readOp, + logical.DeleteOperation: deleteOp, }, HelpSynopsis: pathConfigConnectionHelpSyn, @@ -175,6 +183,34 @@ func (b *databaseBackend) connectionReadHandler() framework.OperationFunc { } } +// connectionDeleteHandler deletes the connection configuration +func (b *databaseBackend) connectionDeleteHandler() framework.OperationFunc { + return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + name := data.Get("name").(string) + if name == "" { + return logical.ErrorResponse("Empty name attribute given"), nil + } + + err := req.Storage.Delete(fmt.Sprintf("dbs/%s", name)) + if err != nil { + return nil, fmt.Errorf("failed to delete connection configuration") + } + + if _, ok := b.connections[name]; ok { + err = b.connections[name].Close() + if err != nil { + return nil, err + } + } + + delete(b.connections, name) + + return nil, nil + } +} + +// connectionWriteHandler returns a handler function for creating and updating +// both builtin and plugin database types. func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework.OperationFunc { return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { connType := data.Get("connection_type").(string) diff --git a/builtin/logical/database/path_role_create.go b/builtin/logical/database/path_role_create.go index c7989c25d870..14b65cbb3106 100644 --- a/builtin/logical/database/path_role_create.go +++ b/builtin/logical/database/path_role_create.go @@ -42,7 +42,7 @@ func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framewo return logical.ErrorResponse(fmt.Sprintf("unknown role: %s", name)), nil } - // Generate the username, password and expiration. PG limits user to 63 characters + // Generate the username, password and expiration // Get our handle b.logger.Trace("postgres/pathRoleCreateRead: getting database handle") From 1be813605fde65ba6d4c862da61200f098300793 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 22 Mar 2017 09:54:19 -0700 Subject: [PATCH 034/152] Fix race with deleting the connection --- builtin/logical/database/path_config_connection.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 3bb3a5631033..ba6a3780509a 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -196,6 +196,9 @@ func (b *databaseBackend) connectionDeleteHandler() framework.OperationFunc { return nil, fmt.Errorf("failed to delete connection configuration") } + b.Lock() + defer b.Unlock() + if _, ok := b.connections[name]; ok { err = b.connections[name].Close() if err != nil { From 9aaec25a4edaa1bd077fbd60edb2c502741d2a42 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 22 Mar 2017 12:40:16 -0700 Subject: [PATCH 035/152] Add a error message for empty creation statement --- builtin/logical/database/dbs/db.go | 1 + builtin/logical/database/dbs/mysql.go | 3 +-- builtin/logical/database/dbs/postgresql.go | 4 ++++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go index b681de360bc4..4554963ac7f4 100644 --- a/builtin/logical/database/dbs/db.go +++ b/builtin/logical/database/dbs/db.go @@ -18,6 +18,7 @@ const ( var ( ErrUnsupportedDatabaseType = errors.New("Unsupported database type") + ErrEmptyCreationStatement = errors.New("Empty creation statements") ) type Factory func(*DatabaseConfig, logical.SystemView) (DatabaseType, error) diff --git a/builtin/logical/database/dbs/mysql.go b/builtin/logical/database/dbs/mysql.go index 0d8be2a470f3..54940d8f65fc 100644 --- a/builtin/logical/database/dbs/mysql.go +++ b/builtin/logical/database/dbs/mysql.go @@ -2,7 +2,6 @@ package dbs import ( "database/sql" - "fmt" "strings" "github.com/hashicorp/vault/helper/strutil" @@ -43,7 +42,7 @@ func (m *MySQL) CreateUser(statements Statements, username, password, expiration } if statements.CreationStatements == "" { - return fmt.Errorf("Empty creation statements") + return ErrEmptyCreationStatement } // Start a transaction diff --git a/builtin/logical/database/dbs/postgresql.go b/builtin/logical/database/dbs/postgresql.go index 51b72ebc8a52..20d548f9204f 100644 --- a/builtin/logical/database/dbs/postgresql.go +++ b/builtin/logical/database/dbs/postgresql.go @@ -28,6 +28,10 @@ func (p *PostgreSQL) getConnection() (*sql.DB, error) { } func (p *PostgreSQL) CreateUser(statements Statements, username, password, expiration string) error { + if statements.CreationStatements == "" { + return ErrEmptyCreationStatement + } + // Grab the lock p.Lock() defer p.Unlock() From 73e553af95b9d5c8380e317f898c28ba3d2960e8 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 22 Mar 2017 16:39:08 -0700 Subject: [PATCH 036/152] Add test files for postgres and mysql databases --- builtin/logical/database/dbs/mysql_test.go | 349 +++++++++++++++ .../logical/database/dbs/postgresql_test.go | 412 ++++++++++++++++++ 2 files changed, 761 insertions(+) create mode 100644 builtin/logical/database/dbs/mysql_test.go create mode 100644 builtin/logical/database/dbs/postgresql_test.go diff --git a/builtin/logical/database/dbs/mysql_test.go b/builtin/logical/database/dbs/mysql_test.go new file mode 100644 index 000000000000..a27dfbba751b --- /dev/null +++ b/builtin/logical/database/dbs/mysql_test.go @@ -0,0 +1,349 @@ +package dbs + +import ( + "database/sql" + "os" + "sync" + "testing" + "time" + + dockertest "gopkg.in/ory-am/dockertest.v2" +) + +var ( + testMySQLImagePull sync.Once +) + +func prepareMySQLTestContainer(t *testing.T) (cid dockertest.ContainerID, retURL string) { + if os.Getenv("PG_URL") != "" { + return "", os.Getenv("PG_URL") + } + + // Without this the checks for whether the container has started seem to + // never actually pass. There's really no reason to expose the test + // containers, so don't. + dockertest.BindDockerToLocalhost = "yep" + + testImagePull.Do(func() { + dockertest.Pull("mysql") + }) + + cid, connErr := dockertest.ConnectToMySQL(60, 500*time.Millisecond, func(connURL string) bool { + // This will cause a validation to run + connProducer := &sqlConnectionProducer{} + connProducer.ConnectionURL = connURL + connProducer.config = &DatabaseConfig{ + DatabaseType: mySQLTypeName, + } + + conn, err := connProducer.connection() + if err != nil { + return false + } + if err := conn.(*sql.DB).Ping(); err != nil { + return false + } + + connProducer.Close() + + retURL = connURL + return true + }) + + if connErr != nil { + t.Fatalf("could not connect to database: %v", connErr) + } + + return +} + +func TestMySQL_Initialize(t *testing.T) { + cid, connURL := prepareMySQLTestContainer(t) + if cid != "" { + defer cleanupTestContainer(t, cid) + } + + conf := &DatabaseConfig{ + DatabaseType: mySQLTypeName, + ConnectionDetails: map[string]interface{}{ + "connection_url": connURL, + }, + } + + dbRaw, err := BuiltinFactory(conf, nil) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Deconsturct the middleware chain to get the underlying postgres object + dbMetrics := dbRaw.(*databaseMetricsMiddleware) + db := dbMetrics.next.(*MySQL) + + err = dbRaw.Initialize(conf.ConnectionDetails) + if err != nil { + t.Fatalf("err: %s", err) + } + + connProducer := db.ConnectionProducer.(*sqlConnectionProducer) + if !connProducer.initalized { + t.Fatal("Database should be initalized") + } + + err = dbRaw.Close() + if err != nil { + t.Fatalf("err: %s", err) + } + + if connProducer.db != nil { + t.Fatal("db object should be nil") + } +} + +func TestMySQL_CreateUser(t *testing.T) { + cid, connURL := prepareMySQLTestContainer(t) + if cid != "" { + defer cleanupTestContainer(t, cid) + } + + conf := &DatabaseConfig{ + DatabaseType: mySQLTypeName, + ConnectionDetails: map[string]interface{}{ + "connection_url": connURL, + }, + } + + db, err := BuiltinFactory(conf, nil) + if err != nil { + t.Fatalf("err: %s", err) + } + err = db.Initialize(conf.ConnectionDetails) + if err != nil { + t.Fatalf("err: %s", err) + } + + username, err := db.GenerateUsername("test") + if err != nil { + t.Fatalf("err: %s", err) + } + + password, err := db.GeneratePassword() + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err := db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Test with no configured Creation Statememt + err = db.CreateUser(Statements{}, username, password, expiration) + if err == nil { + t.Fatal("Expected error when no creation statement is provided") + } + + statements := Statements{ + CreationStatements: testMySQLRoleWildCard, + } + + err = db.CreateUser(statements, username, password, expiration) + if err != nil { + t.Fatalf("err: %s", err) + } + + username, err = db.GenerateUsername("test") + if err != nil { + t.Fatalf("err: %s", err) + } + + password, err = db.GeneratePassword() + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err = db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + statements.CreationStatements = testMySQLRoleHost + err = db.CreateUser(statements, username, password, expiration) + if err != nil { + t.Fatalf("err: %s", err) + } + + username, err = db.GenerateUsername("test") + if err != nil { + t.Fatalf("err: %s", err) + } + + password, err = db.GeneratePassword() + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err = db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + /* statements.CreationStatements = testBlockStatementRole + err = db.CreateUser(statements, username, password, expiration) + if err != nil { + t.Fatalf("err: %s", err) + }*/ +} + +func TestMySQL_RenewUser(t *testing.T) { + cid, connURL := prepareMySQLTestContainer(t) + if cid != "" { + defer cleanupTestContainer(t, cid) + } + + conf := &DatabaseConfig{ + DatabaseType: mySQLTypeName, + ConnectionDetails: map[string]interface{}{ + "connection_url": connURL, + }, + } + + db, err := BuiltinFactory(conf, nil) + if err != nil { + t.Fatalf("err: %s", err) + } + err = db.Initialize(conf.ConnectionDetails) + if err != nil { + t.Fatalf("err: %s", err) + } + + username, err := db.GenerateUsername("test") + if err != nil { + t.Fatalf("err: %s", err) + } + + password, err := db.GeneratePassword() + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err := db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + + statements := Statements{ + CreationStatements: testMySQLRoleWildCard, + } + + err = db.CreateUser(statements, username, password, expiration) + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err = db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + err = db.RenewUser(statements, username, expiration) + if err != nil { + t.Fatalf("err: %s", err) + } +} + +func TestMySQL_RevokeUser(t *testing.T) { + cid, connURL := prepareMySQLTestContainer(t) + if cid != "" { + defer cleanupTestContainer(t, cid) + } + + conf := &DatabaseConfig{ + DatabaseType: mySQLTypeName, + ConnectionDetails: map[string]interface{}{ + "connection_url": connURL, + }, + } + + db, err := BuiltinFactory(conf, nil) + if err != nil { + t.Fatalf("err: %s", err) + } + err = db.Initialize(conf.ConnectionDetails) + if err != nil { + t.Fatalf("err: %s", err) + } + + username, err := db.GenerateUsername("test") + if err != nil { + t.Fatalf("err: %s", err) + } + + password, err := db.GeneratePassword() + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err := db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + + statements := Statements{ + CreationStatements: testMySQLRoleWildCard, + } + + err = db.CreateUser(statements, username, password, expiration) + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err = db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Test default revoke statememts + err = db.RevokeUser(statements, username) + if err != nil { + t.Fatalf("err: %s", err) + } + + username, err = db.GenerateUsername("test") + if err != nil { + t.Fatalf("err: %s", err) + } + + password, err = db.GeneratePassword() + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err = db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + + statements.CreationStatements = testMySQLRoleHost + err = db.CreateUser(statements, username, password, expiration) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Test custom revoke statements + statements.RevocationStatements = testMySQLRevocationSQL + err = db.RevokeUser(statements, username) + if err != nil { + t.Fatalf("err: %s", err) + } + +} + +const testMySQLRoleWildCard = ` +CREATE USER '{{name}}'@'%' IDENTIFIED BY '{{password}}'; +GRANT SELECT ON *.* TO '{{name}}'@'%'; +` +const testMySQLRoleHost = ` +CREATE USER '{{name}}'@'10.1.1.2' IDENTIFIED BY '{{password}}'; +GRANT SELECT ON *.* TO '{{name}}'@'10.1.1.2'; +` +const testMySQLRevocationSQL = ` +REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'10.1.1.2'; +DROP USER '{{name}}'@'10.1.1.2'; +` diff --git a/builtin/logical/database/dbs/postgresql_test.go b/builtin/logical/database/dbs/postgresql_test.go new file mode 100644 index 000000000000..211ab0254186 --- /dev/null +++ b/builtin/logical/database/dbs/postgresql_test.go @@ -0,0 +1,412 @@ +package dbs + +import ( + "database/sql" + "os" + "sync" + "testing" + "time" + + dockertest "gopkg.in/ory-am/dockertest.v2" +) + +var ( + testImagePull sync.Once +) + +func prepareTestContainer(t *testing.T) (cid dockertest.ContainerID, retURL string) { + if os.Getenv("PG_URL") != "" { + return "", os.Getenv("PG_URL") + } + + // Without this the checks for whether the container has started seem to + // never actually pass. There's really no reason to expose the test + // containers, so don't. + dockertest.BindDockerToLocalhost = "yep" + + testImagePull.Do(func() { + dockertest.Pull("postgres") + }) + + cid, connErr := dockertest.ConnectToPostgreSQL(60, 500*time.Millisecond, func(connURL string) bool { + // This will cause a validation to run + connProducer := &sqlConnectionProducer{} + connProducer.ConnectionURL = connURL + connProducer.config = &DatabaseConfig{ + DatabaseType: postgreSQLTypeName, + } + + conn, err := connProducer.connection() + if err != nil { + return false + } + if err := conn.(*sql.DB).Ping(); err != nil { + return false + } + + connProducer.Close() + + retURL = connURL + return true + }) + + if connErr != nil { + t.Fatalf("could not connect to database: %v", connErr) + } + + return +} + +func cleanupTestContainer(t *testing.T, cid dockertest.ContainerID) { + err := cid.KillRemove() + if err != nil { + t.Fatal(err) + } +} + +func TestPostgreSQL_Initialize(t *testing.T) { + cid, connURL := prepareTestContainer(t) + if cid != "" { + defer cleanupTestContainer(t, cid) + } + + conf := &DatabaseConfig{ + DatabaseType: postgreSQLTypeName, + ConnectionDetails: map[string]interface{}{ + "connection_url": connURL, + }, + } + + dbRaw, err := BuiltinFactory(conf, nil) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Deconsturct the middleware chain to get the underlying postgres object + dbMetrics := dbRaw.(*databaseMetricsMiddleware) + db := dbMetrics.next.(*PostgreSQL) + + err = dbRaw.Initialize(conf.ConnectionDetails) + if err != nil { + t.Fatalf("err: %s", err) + } + + connProducer := db.ConnectionProducer.(*sqlConnectionProducer) + if !connProducer.initalized { + t.Fatal("Database should be initalized") + } + + err = dbRaw.Close() + if err != nil { + t.Fatalf("err: %s", err) + } + + if connProducer.db != nil { + t.Fatal("db object should be nil") + } +} + +func TestPostgreSQL_CreateUser(t *testing.T) { + cid, connURL := prepareTestContainer(t) + if cid != "" { + defer cleanupTestContainer(t, cid) + } + + conf := &DatabaseConfig{ + DatabaseType: postgreSQLTypeName, + ConnectionDetails: map[string]interface{}{ + "connection_url": connURL, + }, + } + + db, err := BuiltinFactory(conf, nil) + if err != nil { + t.Fatalf("err: %s", err) + } + err = db.Initialize(conf.ConnectionDetails) + if err != nil { + t.Fatalf("err: %s", err) + } + + username, err := db.GenerateUsername("test") + if err != nil { + t.Fatalf("err: %s", err) + } + + password, err := db.GeneratePassword() + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err := db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Test with no configured Creation Statememt + err = db.CreateUser(Statements{}, username, password, expiration) + if err == nil { + t.Fatal("Expected error when no creation statement is provided") + } + + statements := Statements{ + CreationStatements: testRole, + } + + err = db.CreateUser(statements, username, password, expiration) + if err != nil { + t.Fatalf("err: %s", err) + } + + username, err = db.GenerateUsername("test") + if err != nil { + t.Fatalf("err: %s", err) + } + + password, err = db.GeneratePassword() + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err = db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + statements.CreationStatements = testReadOnlyRole + err = db.CreateUser(statements, username, password, expiration) + if err != nil { + t.Fatalf("err: %s", err) + } + + username, err = db.GenerateUsername("test") + if err != nil { + t.Fatalf("err: %s", err) + } + + password, err = db.GeneratePassword() + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err = db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + /* statements.CreationStatements = testBlockStatementRole + err = db.CreateUser(statements, username, password, expiration) + if err != nil { + t.Fatalf("err: %s", err) + }*/ +} + +func TestPostgreSQL_RenewUser(t *testing.T) { + cid, connURL := prepareTestContainer(t) + if cid != "" { + defer cleanupTestContainer(t, cid) + } + + conf := &DatabaseConfig{ + DatabaseType: postgreSQLTypeName, + ConnectionDetails: map[string]interface{}{ + "connection_url": connURL, + }, + } + + db, err := BuiltinFactory(conf, nil) + if err != nil { + t.Fatalf("err: %s", err) + } + err = db.Initialize(conf.ConnectionDetails) + if err != nil { + t.Fatalf("err: %s", err) + } + + username, err := db.GenerateUsername("test") + if err != nil { + t.Fatalf("err: %s", err) + } + + password, err := db.GeneratePassword() + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err := db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + + statements := Statements{ + CreationStatements: testRole, + } + + err = db.CreateUser(statements, username, password, expiration) + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err = db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + err = db.RenewUser(statements, username, expiration) + if err != nil { + t.Fatalf("err: %s", err) + } +} + +func TestPostgreSQL_RevokeUser(t *testing.T) { + cid, connURL := prepareTestContainer(t) + if cid != "" { + defer cleanupTestContainer(t, cid) + } + + conf := &DatabaseConfig{ + DatabaseType: postgreSQLTypeName, + ConnectionDetails: map[string]interface{}{ + "connection_url": connURL, + }, + } + + db, err := BuiltinFactory(conf, nil) + if err != nil { + t.Fatalf("err: %s", err) + } + err = db.Initialize(conf.ConnectionDetails) + if err != nil { + t.Fatalf("err: %s", err) + } + + username, err := db.GenerateUsername("test") + if err != nil { + t.Fatalf("err: %s", err) + } + + password, err := db.GeneratePassword() + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err := db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + + statements := Statements{ + CreationStatements: testRole, + } + + err = db.CreateUser(statements, username, password, expiration) + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err = db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Test default revoke statememts + err = db.RevokeUser(statements, username) + if err != nil { + t.Fatalf("err: %s", err) + } + + username, err = db.GenerateUsername("test") + if err != nil { + t.Fatalf("err: %s", err) + } + + password, err = db.GeneratePassword() + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err = db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + + err = db.CreateUser(statements, username, password, expiration) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Test custom revoke statements + statements.RevocationStatements = defaultRevocationSQL + err = db.RevokeUser(statements, username) + if err != nil { + t.Fatalf("err: %s", err) + } + +} + +const testRole = ` +CREATE ROLE "{{name}}" WITH + LOGIN + PASSWORD '{{password}}' + VALID UNTIL '{{expiration}}'; +GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}"; +` + +const testReadOnlyRole = ` +CREATE ROLE "{{name}}" WITH + LOGIN + PASSWORD '{{password}}' + VALID UNTIL '{{expiration}}'; +GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}"; +GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}"; +` + +const testBlockStatementRole = ` +DO $$ +BEGIN + IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN + CREATE ROLE "foo-role"; + CREATE SCHEMA IF NOT EXISTS foo AUTHORIZATION "foo-role"; + ALTER ROLE "foo-role" SET search_path = foo; + GRANT TEMPORARY ON DATABASE "postgres" TO "foo-role"; + GRANT ALL PRIVILEGES ON SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA foo TO "foo-role"; + END IF; +END +$$ + +CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}'; +GRANT "foo-role" TO "{{name}}"; +ALTER ROLE "{{name}}" SET search_path = foo; +GRANT CONNECT ON DATABASE "postgres" TO "{{name}}"; +` + +var testBlockStatementRoleSlice = []string{ + ` +DO $$ +BEGIN + IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN + CREATE ROLE "foo-role"; + CREATE SCHEMA IF NOT EXISTS foo AUTHORIZATION "foo-role"; + ALTER ROLE "foo-role" SET search_path = foo; + GRANT TEMPORARY ON DATABASE "postgres" TO "foo-role"; + GRANT ALL PRIVILEGES ON SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA foo TO "foo-role"; + END IF; +END +$$ +`, + `CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}';`, + `GRANT "foo-role" TO "{{name}}";`, + `ALTER ROLE "{{name}}" SET search_path = foo;`, + `GRANT CONNECT ON DATABASE "postgres" TO "{{name}}";`, +} + +const defaultRevocationSQL = ` +REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM "{{name}}"; +REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM "{{name}}"; +REVOKE USAGE ON SCHEMA public FROM "{{name}}"; + +DROP ROLE IF EXISTS "{{name}}"; +` From cab491f7b72298c456393862fce12b3c05b311af Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 22 Mar 2017 16:44:33 -0700 Subject: [PATCH 037/152] s/postgres/mysql/ --- builtin/logical/database/dbs/mysql_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/builtin/logical/database/dbs/mysql_test.go b/builtin/logical/database/dbs/mysql_test.go index a27dfbba751b..e7edabeb7bc0 100644 --- a/builtin/logical/database/dbs/mysql_test.go +++ b/builtin/logical/database/dbs/mysql_test.go @@ -15,8 +15,8 @@ var ( ) func prepareMySQLTestContainer(t *testing.T) (cid dockertest.ContainerID, retURL string) { - if os.Getenv("PG_URL") != "" { - return "", os.Getenv("PG_URL") + if os.Getenv("MYSQL_URL") != "" { + return "", os.Getenv("MYSQL_URL") } // Without this the checks for whether the container has started seem to @@ -75,7 +75,7 @@ func TestMySQL_Initialize(t *testing.T) { t.Fatalf("err: %s", err) } - // Deconsturct the middleware chain to get the underlying postgres object + // Deconsturct the middleware chain to get the underlying mysql object dbMetrics := dbRaw.(*databaseMetricsMiddleware) db := dbMetrics.next.(*MySQL) From a1b72465ddb187e9e6fffffa3600a246d35e21e7 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 22 Mar 2017 17:09:39 -0700 Subject: [PATCH 038/152] Remove unsused code block --- builtin/logical/database/dbs/mysql_test.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/builtin/logical/database/dbs/mysql_test.go b/builtin/logical/database/dbs/mysql_test.go index e7edabeb7bc0..c489ffaea389 100644 --- a/builtin/logical/database/dbs/mysql_test.go +++ b/builtin/logical/database/dbs/mysql_test.go @@ -185,11 +185,6 @@ func TestMySQL_CreateUser(t *testing.T) { if err != nil { t.Fatalf("err: %s", err) } - /* statements.CreationStatements = testBlockStatementRole - err = db.CreateUser(statements, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - }*/ } func TestMySQL_RenewUser(t *testing.T) { From e870e399a272ef78954374bce08d6c7c91c78472 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 23 Mar 2017 15:54:15 -0700 Subject: [PATCH 039/152] More work on getting tests to pass --- builtin/logical/database/backend_test.go | 620 ------------------ builtin/logical/database/dbs/mysql_test.go | 2 +- builtin/logical/database/dbs/plugin.go | 9 +- builtin/logical/database/dbs/plugin_test.go | 325 +++++++++ .../logical/database/dbs/postgresql_test.go | 34 +- helper/pluginutil/tls.go | 7 +- vault/testing.go | 16 + 7 files changed, 369 insertions(+), 644 deletions(-) delete mode 100644 builtin/logical/database/backend_test.go create mode 100644 builtin/logical/database/dbs/plugin_test.go diff --git a/builtin/logical/database/backend_test.go b/builtin/logical/database/backend_test.go deleted file mode 100644 index a203c9b19145..000000000000 --- a/builtin/logical/database/backend_test.go +++ /dev/null @@ -1,620 +0,0 @@ -package database - -import ( - "database/sql" - "encoding/json" - "fmt" - "log" - "os" - "path" - "reflect" - "sync" - "testing" - "time" - - "github.com/hashicorp/vault/logical" - logicaltest "github.com/hashicorp/vault/logical/testing" - "github.com/lib/pq" - "github.com/mitchellh/mapstructure" - "github.com/ory-am/dockertest" -) - -var ( - testImagePull sync.Once -) - -func prepareTestContainer(t *testing.T, s logical.Storage, b logical.Backend) (cid dockertest.ContainerID, retURL string) { - if os.Getenv("PG_URL") != "" { - return "", os.Getenv("PG_URL") - } - - // Without this the checks for whether the container has started seem to - // never actually pass. There's really no reason to expose the test - // containers, so don't. - dockertest.BindDockerToLocalhost = "yep" - - testImagePull.Do(func() { - dockertest.Pull("postgres") - }) - - cid, connErr := dockertest.ConnectToPostgreSQL(60, 500*time.Millisecond, func(connURL string) bool { - // This will cause a validation to run - resp, err := b.HandleRequest(&logical.Request{ - Storage: s, - Operation: logical.UpdateOperation, - Path: "config/connection", - Data: map[string]interface{}{ - "connection_url": connURL, - }, - }) - if err != nil || (resp != nil && resp.IsError()) { - // It's likely not up and running yet, so return false and try again - return false - } - if resp == nil { - t.Fatal("expected warning") - } - - retURL = connURL - return true - }) - - if connErr != nil { - t.Fatalf("could not connect to database: %v", connErr) - } - - return -} - -func cleanupTestContainer(t *testing.T, cid dockertest.ContainerID) { - err := cid.KillRemove() - if err != nil { - t.Fatal(err) - } -} - -func TestBackend_config_connection(t *testing.T) { - var resp *logical.Response - var err error - config := logical.TestBackendConfig() - config.StorageView = &logical.InmemStorage{} - b, err := Factory(config) - if err != nil { - t.Fatal(err) - } - - configData := map[string]interface{}{ - "connection_url": "sample_connection_url", - "value": "", - "max_open_connections": 9, - "max_idle_connections": 7, - "verify_connection": false, - } - - configReq := &logical.Request{ - Operation: logical.UpdateOperation, - Path: "config/connection", - Storage: config.StorageView, - Data: configData, - } - resp, err = b.HandleRequest(configReq) - if err != nil || (resp != nil && resp.IsError()) { - t.Fatalf("err:%s resp:%#v\n", err, resp) - } - - configReq.Operation = logical.ReadOperation - resp, err = b.HandleRequest(configReq) - if err != nil || (resp != nil && resp.IsError()) { - t.Fatalf("err:%s resp:%#v\n", err, resp) - } - - delete(configData, "verify_connection") - if !reflect.DeepEqual(configData, resp.Data) { - t.Fatalf("bad: expected:%#v\nactual:%#v\n", configData, resp.Data) - } -} - -func TestBackend_basic(t *testing.T) { - config := logical.TestBackendConfig() - config.StorageView = &logical.InmemStorage{} - b, err := Factory(config) - if err != nil { - t.Fatal(err) - } - - cid, connURL := prepareTestContainer(t, config.StorageView, b) - if cid != "" { - defer cleanupTestContainer(t, cid) - } - connData := map[string]interface{}{ - "connection_url": connURL, - } - - logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, - Steps: []logicaltest.TestStep{ - testAccStepConfig(t, connData, false), - testAccStepCreateRole(t, "web", testRole, false), - testAccStepReadCreds(t, b, config.StorageView, "web", connURL), - }, - }) -} - -func TestBackend_roleCrud(t *testing.T) { - config := logical.TestBackendConfig() - config.StorageView = &logical.InmemStorage{} - b, err := Factory(config) - if err != nil { - t.Fatal(err) - } - - cid, connURL := prepareTestContainer(t, config.StorageView, b) - if cid != "" { - defer cleanupTestContainer(t, cid) - } - connData := map[string]interface{}{ - "connection_url": connURL, - } - - logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, - Steps: []logicaltest.TestStep{ - testAccStepConfig(t, connData, false), - testAccStepCreateRole(t, "web", testRole, false), - testAccStepReadRole(t, "web", testRole), - testAccStepDeleteRole(t, "web"), - testAccStepReadRole(t, "web", ""), - }, - }) -} - -func TestBackend_BlockStatements(t *testing.T) { - config := logical.TestBackendConfig() - config.StorageView = &logical.InmemStorage{} - b, err := Factory(config) - if err != nil { - t.Fatal(err) - } - - cid, connURL := prepareTestContainer(t, config.StorageView, b) - if cid != "" { - defer cleanupTestContainer(t, cid) - } - connData := map[string]interface{}{ - "connection_url": connURL, - } - - jsonBlockStatement, err := json.Marshal(testBlockStatementRoleSlice) - if err != nil { - t.Fatal(err) - } - - logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, - Steps: []logicaltest.TestStep{ - testAccStepConfig(t, connData, false), - // This will also validate the query - testAccStepCreateRole(t, "web-block", testBlockStatementRole, true), - testAccStepCreateRole(t, "web-block", string(jsonBlockStatement), false), - }, - }) -} - -func TestBackend_roleReadOnly(t *testing.T) { - config := logical.TestBackendConfig() - config.StorageView = &logical.InmemStorage{} - b, err := Factory(config) - if err != nil { - t.Fatal(err) - } - - cid, connURL := prepareTestContainer(t, config.StorageView, b) - if cid != "" { - defer cleanupTestContainer(t, cid) - } - connData := map[string]interface{}{ - "connection_url": connURL, - } - - logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, - Steps: []logicaltest.TestStep{ - testAccStepConfig(t, connData, false), - testAccStepCreateRole(t, "web", testRole, false), - testAccStepCreateRole(t, "web-readonly", testReadOnlyRole, false), - testAccStepReadRole(t, "web-readonly", testReadOnlyRole), - testAccStepCreateTable(t, b, config.StorageView, "web", connURL), - testAccStepReadCreds(t, b, config.StorageView, "web-readonly", connURL), - testAccStepDropTable(t, b, config.StorageView, "web", connURL), - testAccStepDeleteRole(t, "web-readonly"), - testAccStepDeleteRole(t, "web"), - testAccStepReadRole(t, "web-readonly", ""), - }, - }) -} - -func TestBackend_roleReadOnly_revocationSQL(t *testing.T) { - config := logical.TestBackendConfig() - config.StorageView = &logical.InmemStorage{} - b, err := Factory(config) - if err != nil { - t.Fatal(err) - } - - cid, connURL := prepareTestContainer(t, config.StorageView, b) - if cid != "" { - defer cleanupTestContainer(t, cid) - } - connData := map[string]interface{}{ - "connection_url": connURL, - } - - logicaltest.Test(t, logicaltest.TestCase{ - Backend: b, - Steps: []logicaltest.TestStep{ - testAccStepConfig(t, connData, false), - testAccStepCreateRoleWithRevocationSQL(t, "web", testRole, defaultRevocationSQL, false), - testAccStepCreateRoleWithRevocationSQL(t, "web-readonly", testReadOnlyRole, defaultRevocationSQL, false), - testAccStepReadRole(t, "web-readonly", testReadOnlyRole), - testAccStepCreateTable(t, b, config.StorageView, "web", connURL), - testAccStepReadCreds(t, b, config.StorageView, "web-readonly", connURL), - testAccStepDropTable(t, b, config.StorageView, "web", connURL), - testAccStepDeleteRole(t, "web-readonly"), - testAccStepDeleteRole(t, "web"), - testAccStepReadRole(t, "web-readonly", ""), - }, - }) -} - -func testAccStepConfig(t *testing.T, d map[string]interface{}, expectError bool) logicaltest.TestStep { - return logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: "config/connection", - Data: d, - ErrorOk: true, - Check: func(resp *logical.Response) error { - if expectError { - if resp.Data == nil { - return fmt.Errorf("data is nil") - } - var e struct { - Error string `mapstructure:"error"` - } - if err := mapstructure.Decode(resp.Data, &e); err != nil { - return err - } - if len(e.Error) == 0 { - return fmt.Errorf("expected error, but write succeeded.") - } - return nil - } else if resp != nil && resp.IsError() { - return fmt.Errorf("got an error response: %v", resp.Error()) - } - return nil - }, - } -} - -func testAccStepCreateRole(t *testing.T, name string, sql string, expectFail bool) logicaltest.TestStep { - return logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: path.Join("roles", name), - Data: map[string]interface{}{ - "sql": sql, - }, - ErrorOk: expectFail, - } -} - -func testAccStepCreateRoleWithRevocationSQL(t *testing.T, name, sql, revocationSQL string, expectFail bool) logicaltest.TestStep { - return logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: path.Join("roles", name), - Data: map[string]interface{}{ - "sql": sql, - "revocation_sql": revocationSQL, - }, - ErrorOk: expectFail, - } -} - -func testAccStepDeleteRole(t *testing.T, name string) logicaltest.TestStep { - return logicaltest.TestStep{ - Operation: logical.DeleteOperation, - Path: path.Join("roles", name), - } -} - -func testAccStepReadCreds(t *testing.T, b logical.Backend, s logical.Storage, name string, connURL string) logicaltest.TestStep { - return logicaltest.TestStep{ - Operation: logical.ReadOperation, - Path: path.Join("creds", name), - Check: func(resp *logical.Response) error { - var d struct { - Username string `mapstructure:"username"` - Password string `mapstructure:"password"` - } - if err := mapstructure.Decode(resp.Data, &d); err != nil { - return err - } - log.Printf("[TRACE] Generated credentials: %v", d) - conn, err := pq.ParseURL(connURL) - - if err != nil { - t.Fatal(err) - } - - conn += " timezone=utc" - - db, err := sql.Open("postgres", conn) - if err != nil { - t.Fatal(err) - } - - returnedRows := func() int { - stmt, err := db.Prepare("SELECT DISTINCT schemaname FROM pg_tables WHERE has_table_privilege($1, 'information_schema.role_column_grants', 'select');") - if err != nil { - return -1 - } - defer stmt.Close() - - rows, err := stmt.Query(d.Username) - if err != nil { - return -1 - } - defer rows.Close() - - i := 0 - for rows.Next() { - i++ - } - return i - } - - // minNumPermissions is the minimum number of permissions that will always be present. - const minNumPermissions = 2 - - userRows := returnedRows() - if userRows < minNumPermissions { - t.Fatalf("did not get expected number of rows, got %d", userRows) - } - - resp, err = b.HandleRequest(&logical.Request{ - Operation: logical.RevokeOperation, - Storage: s, - Secret: &logical.Secret{ - InternalData: map[string]interface{}{ - "secret_type": "creds", - "username": d.Username, - "role": name, - }, - }, - }) - if err != nil { - return err - } - if resp != nil { - if resp.IsError() { - return fmt.Errorf("Error on resp: %#v", *resp) - } - } - - userRows = returnedRows() - // User shouldn't exist so returnedRows() should encounter an error and exit with -1 - if userRows != -1 { - t.Fatalf("did not get expected number of rows, got %d", userRows) - } - - return nil - }, - } -} - -func testAccStepCreateTable(t *testing.T, b logical.Backend, s logical.Storage, name string, connURL string) logicaltest.TestStep { - return logicaltest.TestStep{ - Operation: logical.ReadOperation, - Path: path.Join("creds", name), - Check: func(resp *logical.Response) error { - var d struct { - Username string `mapstructure:"username"` - Password string `mapstructure:"password"` - } - if err := mapstructure.Decode(resp.Data, &d); err != nil { - return err - } - log.Printf("[TRACE] Generated credentials: %v", d) - conn, err := pq.ParseURL(connURL) - - if err != nil { - t.Fatal(err) - } - - conn += " timezone=utc" - - db, err := sql.Open("postgres", conn) - if err != nil { - t.Fatal(err) - } - - _, err = db.Exec("CREATE TABLE test (id SERIAL PRIMARY KEY);") - if err != nil { - t.Fatal(err) - } - - resp, err = b.HandleRequest(&logical.Request{ - Operation: logical.RevokeOperation, - Storage: s, - Secret: &logical.Secret{ - InternalData: map[string]interface{}{ - "secret_type": "creds", - "username": d.Username, - }, - }, - }) - if err != nil { - return err - } - if resp != nil { - if resp.IsError() { - return fmt.Errorf("Error on resp: %#v", *resp) - } - } - - return nil - }, - } -} - -func testAccStepDropTable(t *testing.T, b logical.Backend, s logical.Storage, name string, connURL string) logicaltest.TestStep { - return logicaltest.TestStep{ - Operation: logical.ReadOperation, - Path: path.Join("creds", name), - Check: func(resp *logical.Response) error { - var d struct { - Username string `mapstructure:"username"` - Password string `mapstructure:"password"` - } - if err := mapstructure.Decode(resp.Data, &d); err != nil { - return err - } - log.Printf("[TRACE] Generated credentials: %v", d) - conn, err := pq.ParseURL(connURL) - - if err != nil { - t.Fatal(err) - } - - conn += " timezone=utc" - - db, err := sql.Open("postgres", conn) - if err != nil { - t.Fatal(err) - } - - _, err = db.Exec("DROP TABLE test;") - if err != nil { - t.Fatal(err) - } - - resp, err = b.HandleRequest(&logical.Request{ - Operation: logical.RevokeOperation, - Storage: s, - Secret: &logical.Secret{ - InternalData: map[string]interface{}{ - "secret_type": "creds", - "username": d.Username, - }, - }, - }) - if err != nil { - return err - } - if resp != nil { - if resp.IsError() { - return fmt.Errorf("Error on resp: %#v", *resp) - } - } - - return nil - }, - } -} - -func testAccStepReadRole(t *testing.T, name string, sql string) logicaltest.TestStep { - return logicaltest.TestStep{ - Operation: logical.ReadOperation, - Path: "roles/" + name, - Check: func(resp *logical.Response) error { - if resp == nil { - if sql == "" { - return nil - } - - return fmt.Errorf("bad: %#v", resp) - } - - var d struct { - SQL string `mapstructure:"sql"` - } - if err := mapstructure.Decode(resp.Data, &d); err != nil { - return err - } - - if d.SQL != sql { - return fmt.Errorf("bad: %#v", resp) - } - - return nil - }, - } -} - -const testRole = ` -CREATE ROLE "{{name}}" WITH - LOGIN - PASSWORD '{{password}}' - VALID UNTIL '{{expiration}}'; -GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}"; -` - -const testReadOnlyRole = ` -CREATE ROLE "{{name}}" WITH - LOGIN - PASSWORD '{{password}}' - VALID UNTIL '{{expiration}}'; -GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}"; -GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}"; -` - -const testBlockStatementRole = ` -DO $$ -BEGIN - IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN - CREATE ROLE "foo-role"; - CREATE SCHEMA IF NOT EXISTS foo AUTHORIZATION "foo-role"; - ALTER ROLE "foo-role" SET search_path = foo; - GRANT TEMPORARY ON DATABASE "postgres" TO "foo-role"; - GRANT ALL PRIVILEGES ON SCHEMA foo TO "foo-role"; - GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA foo TO "foo-role"; - GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA foo TO "foo-role"; - GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA foo TO "foo-role"; - END IF; -END -$$ - -CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}'; -GRANT "foo-role" TO "{{name}}"; -ALTER ROLE "{{name}}" SET search_path = foo; -GRANT CONNECT ON DATABASE "postgres" TO "{{name}}"; -` - -var testBlockStatementRoleSlice = []string{ - ` -DO $$ -BEGIN - IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN - CREATE ROLE "foo-role"; - CREATE SCHEMA IF NOT EXISTS foo AUTHORIZATION "foo-role"; - ALTER ROLE "foo-role" SET search_path = foo; - GRANT TEMPORARY ON DATABASE "postgres" TO "foo-role"; - GRANT ALL PRIVILEGES ON SCHEMA foo TO "foo-role"; - GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA foo TO "foo-role"; - GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA foo TO "foo-role"; - GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA foo TO "foo-role"; - END IF; -END -$$ -`, - `CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}';`, - `GRANT "foo-role" TO "{{name}}";`, - `ALTER ROLE "{{name}}" SET search_path = foo;`, - `GRANT CONNECT ON DATABASE "postgres" TO "{{name}}";`, -} - -const defaultRevocationSQL = ` -REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM {{name}}; -REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM {{name}}; -REVOKE USAGE ON SCHEMA public FROM {{name}}; - -DROP ROLE IF EXISTS {{name}}; -` diff --git a/builtin/logical/database/dbs/mysql_test.go b/builtin/logical/database/dbs/mysql_test.go index c489ffaea389..f4d1247023c1 100644 --- a/builtin/logical/database/dbs/mysql_test.go +++ b/builtin/logical/database/dbs/mysql_test.go @@ -24,7 +24,7 @@ func prepareMySQLTestContainer(t *testing.T) (cid dockertest.ContainerID, retURL // containers, so don't. dockertest.BindDockerToLocalhost = "yep" - testImagePull.Do(func() { + testMySQLImagePull.Do(func() { dockertest.Pull("mysql") }) diff --git a/builtin/logical/database/dbs/plugin.go b/builtin/logical/database/dbs/plugin.go index 45c815a45d93..1213a3677063 100644 --- a/builtin/logical/database/dbs/plugin.go +++ b/builtin/logical/database/dbs/plugin.go @@ -6,12 +6,12 @@ import ( "fmt" "net/rpc" "os/exec" + "strings" "sync" "time" "github.com/hashicorp/go-plugin" "github.com/hashicorp/vault/helper/pluginutil" - "github.com/hashicorp/vault/logical" ) // handshakeConfigs are used to just do a basic handshake between @@ -55,7 +55,7 @@ func (dc *DatabasePluginClient) Close() error { // newPluginClient returns a databaseRPCClient with a connection to a running // plugin. The client is wrapped in a DatabasePluginClient object to ensure the // plugin is killed on call of Close(). -func newPluginClient(sys logical.SystemView, command, checksum string) (DatabaseType, error) { +func newPluginClient(sys pluginutil.Wrapper, command, checksum string) (DatabaseType, error) { // pluginMap is the map of plugins we can dispense. var pluginMap = map[string]plugin.Plugin{ "database": new(DatabasePlugin), @@ -81,7 +81,8 @@ func newPluginClient(sys logical.SystemView, command, checksum string) (Database } // Add the response wrap token to the ENV of the plugin - cmd := exec.Command(command) + commandArr := strings.Split(command, " ") + cmd := exec.Command(commandArr[0], commandArr[1]) cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", pluginutil.PluginUnwrapTokenEnv, wrapToken)) checksumDecoded, err := hex.DecodeString(checksum) @@ -265,7 +266,7 @@ func (ds *databasePluginRPCServer) Initialize(args map[string]interface{}, _ *st return err } -func (ds *databasePluginRPCServer) Close(_ interface{}, _ *struct{}) error { +func (ds *databasePluginRPCServer) Close(_ struct{}, _ *struct{}) error { ds.impl.Close() return nil } diff --git a/builtin/logical/database/dbs/plugin_test.go b/builtin/logical/database/dbs/plugin_test.go new file mode 100644 index 000000000000..74e103c4ae6e --- /dev/null +++ b/builtin/logical/database/dbs/plugin_test.go @@ -0,0 +1,325 @@ +package dbs + +import ( + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "io" + "net" + "os" + "os/exec" + "sync" + "testing" + "time" + + "github.com/hashicorp/vault/helper/pluginutil" + "github.com/hashicorp/vault/http" + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/vault" +) + +var ( + testPluginImagePull sync.Once +) + +type mockPlugin struct { + users map[string][]string + CredentialsProducer +} + +func (m *mockPlugin) Type() string { return "mock" } +func (m *mockPlugin) CreateUser(statements Statements, username, password, expiration string) error { + err := errors.New("err") + if username == "" || password == "" || expiration == "" { + return err + } + + if _, ok := m.users[username]; ok { + return err + } + + m.users[username] = []string{password, expiration} + + return nil +} +func (m *mockPlugin) RenewUser(statements Statements, username, expiration string) error { + err := errors.New("err") + if username == "" || expiration == "" { + return err + } + + if _, ok := m.users[username]; !ok { + return err + } + + return nil +} +func (m *mockPlugin) RevokeUser(statements Statements, username string) error { + err := errors.New("err") + if username == "" { + return err + } + + if _, ok := m.users[username]; !ok { + return err + } + + delete(m.users, username) + return nil +} +func (m *mockPlugin) Initialize(conf map[string]interface{}) error { + err := errors.New("err") + if len(conf) != 1 { + return err + } + + return nil +} +func (m *mockPlugin) Close() error { + m.users = nil + return nil +} + +func getConf(t *testing.T) *DatabaseConfig { + command := fmt.Sprintf("%s -test.run=TestPlugin_Main", os.Args[0]) + cmd := exec.Command(os.Args[0]) + hash := sha256.New() + + file, err := os.Open(cmd.Path) + if err != nil { + t.Fatal(err) + } + defer file.Close() + + _, err = io.Copy(hash, file) + if err != nil { + t.Fatal(err) + } + + sum := hash.Sum(nil) + + conf := &DatabaseConfig{ + DatabaseType: pluginTypeName, + PluginCommand: command, + PluginChecksum: hex.EncodeToString(sum), + ConnectionDetails: map[string]interface{}{ + "test": true, + }, + } + + return conf +} + +func getCore(t *testing.T) (*vault.Core, net.Listener, logical.SystemView) { + core, _, _, ln := vault.TestCoreUnsealedWithListener(t) + http.TestServerWithListener(t, ln, "", core) + sys := vault.TestDynamicSystemView(core) + + return core, ln, sys +} + +func TestPlugin_Main(t *testing.T) { + if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" { + return + } + + plugin := &mockPlugin{ + users: make(map[string][]string), + CredentialsProducer: &sqlCredentialsProducer{5, 50}, + } + + NewPluginServer(plugin) +} + +func TestPlugin_Initialize(t *testing.T) { + _, ln, sys := getCore(t) + defer ln.Close() + + conf := getConf(t) + dbRaw, err := PluginFactory(conf, sys) + if err != nil { + t.Fatalf("err: %s", err) + } + + err = dbRaw.Initialize(conf.ConnectionDetails) + if err != nil { + t.Fatalf("err: %s", err) + } + + err = dbRaw.Close() + if err != nil { + t.Fatalf("err: %s", err) + } +} + +func TestPlugin_CreateUser(t *testing.T) { + _, ln, sys := getCore(t) + defer ln.Close() + + conf := getConf(t) + db, err := PluginFactory(conf, sys) + if err != nil { + t.Fatalf("err: %s", err) + } + defer db.Close() + + err = db.Initialize(conf.ConnectionDetails) + if err != nil { + t.Fatalf("err: %s", err) + } + + username, err := db.GenerateUsername("test") + if err != nil { + t.Fatalf("err: %s", err) + } + + password, err := db.GeneratePassword() + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err := db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + + err = db.CreateUser(Statements{}, username, password, expiration) + if err != nil { + t.Fatalf("err: %s", err) + } + // try and save the same user again to verify it saved the first time, this + // should return an error + err = db.CreateUser(Statements{}, username, password, expiration) + if err == nil { + t.Fatal("expected an error, user wasn't created correctly") + } + + // Create one more user + username, err = db.GenerateUsername("test") + if err != nil { + t.Fatalf("err: %s", err) + } + + err = db.CreateUser(Statements{}, username, password, expiration) + if err != nil { + t.Fatalf("err: %s", err) + } +} + +func TestPlugin_RenewUser(t *testing.T) { + _, ln, sys := getCore(t) + defer ln.Close() + + conf := getConf(t) + db, err := PluginFactory(conf, sys) + if err != nil { + t.Fatalf("err: %s", err) + } + defer db.Close() + + err = db.Initialize(conf.ConnectionDetails) + if err != nil { + t.Fatalf("err: %s", err) + } + + username, err := db.GenerateUsername("test") + if err != nil { + t.Fatalf("err: %s", err) + } + + password, err := db.GeneratePassword() + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err := db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + + err = db.CreateUser(Statements{}, username, password, expiration) + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err = db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + err = db.RenewUser(Statements{}, username, expiration) + if err != nil { + t.Fatalf("err: %s", err) + } +} + +func TestPlugin_RevokeUser(t *testing.T) { + _, ln, sys := getCore(t) + defer ln.Close() + + conf := getConf(t) + db, err := PluginFactory(conf, sys) + if err != nil { + t.Fatalf("err: %s", err) + } + defer db.Close() + + err = db.Initialize(conf.ConnectionDetails) + if err != nil { + t.Fatalf("err: %s", err) + } + + username, err := db.GenerateUsername("test") + if err != nil { + t.Fatalf("err: %s", err) + } + + password, err := db.GeneratePassword() + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err := db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + + err = db.CreateUser(Statements{}, username, password, expiration) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Test default revoke statememts + err = db.RevokeUser(Statements{}, username) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Try adding the same username back so we can verify it was removed + err = db.CreateUser(Statements{}, username, password, expiration) + if err != nil { + t.Fatalf("err: %s", err) + } + + username, err = db.GenerateUsername("test") + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err = db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + + // try once more + err = db.CreateUser(Statements{}, username, password, expiration) + if err != nil { + t.Fatalf("err: %s", err) + } + + err = db.RevokeUser(Statements{}, username) + if err != nil { + t.Fatalf("err: %s", err) + } + +} diff --git a/builtin/logical/database/dbs/postgresql_test.go b/builtin/logical/database/dbs/postgresql_test.go index 211ab0254186..dab720920b8b 100644 --- a/builtin/logical/database/dbs/postgresql_test.go +++ b/builtin/logical/database/dbs/postgresql_test.go @@ -11,10 +11,10 @@ import ( ) var ( - testImagePull sync.Once + testPostgresImagePull sync.Once ) -func prepareTestContainer(t *testing.T) (cid dockertest.ContainerID, retURL string) { +func preparePostgresTestContainer(t *testing.T) (cid dockertest.ContainerID, retURL string) { if os.Getenv("PG_URL") != "" { return "", os.Getenv("PG_URL") } @@ -24,7 +24,7 @@ func prepareTestContainer(t *testing.T) (cid dockertest.ContainerID, retURL stri // containers, so don't. dockertest.BindDockerToLocalhost = "yep" - testImagePull.Do(func() { + testPostgresImagePull.Do(func() { dockertest.Pull("postgres") }) @@ -65,7 +65,7 @@ func cleanupTestContainer(t *testing.T, cid dockertest.ContainerID) { } func TestPostgreSQL_Initialize(t *testing.T) { - cid, connURL := prepareTestContainer(t) + cid, connURL := preparePostgresTestContainer(t) if cid != "" { defer cleanupTestContainer(t, cid) } @@ -107,7 +107,7 @@ func TestPostgreSQL_Initialize(t *testing.T) { } func TestPostgreSQL_CreateUser(t *testing.T) { - cid, connURL := prepareTestContainer(t) + cid, connURL := preparePostgresTestContainer(t) if cid != "" { defer cleanupTestContainer(t, cid) } @@ -150,7 +150,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) { } statements := Statements{ - CreationStatements: testRole, + CreationStatements: testPostgresRole, } err = db.CreateUser(statements, username, password, expiration) @@ -172,7 +172,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) { if err != nil { t.Fatalf("err: %s", err) } - statements.CreationStatements = testReadOnlyRole + statements.CreationStatements = testPostgresReadOnlyRole err = db.CreateUser(statements, username, password, expiration) if err != nil { t.Fatalf("err: %s", err) @@ -200,7 +200,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) { } func TestPostgreSQL_RenewUser(t *testing.T) { - cid, connURL := prepareTestContainer(t) + cid, connURL := preparePostgresTestContainer(t) if cid != "" { defer cleanupTestContainer(t, cid) } @@ -237,7 +237,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) { } statements := Statements{ - CreationStatements: testRole, + CreationStatements: testPostgresRole, } err = db.CreateUser(statements, username, password, expiration) @@ -256,7 +256,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) { } func TestPostgreSQL_RevokeUser(t *testing.T) { - cid, connURL := prepareTestContainer(t) + cid, connURL := preparePostgresTestContainer(t) if cid != "" { defer cleanupTestContainer(t, cid) } @@ -293,7 +293,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) { } statements := Statements{ - CreationStatements: testRole, + CreationStatements: testPostgresRole, } err = db.CreateUser(statements, username, password, expiration) @@ -333,7 +333,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) { } // Test custom revoke statements - statements.RevocationStatements = defaultRevocationSQL + statements.RevocationStatements = defaultPostgresRevocationSQL err = db.RevokeUser(statements, username) if err != nil { t.Fatalf("err: %s", err) @@ -341,7 +341,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) { } -const testRole = ` +const testPostgresRole = ` CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' @@ -349,7 +349,7 @@ CREATE ROLE "{{name}}" WITH GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}"; ` -const testReadOnlyRole = ` +const testPostgresReadOnlyRole = ` CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' @@ -358,7 +358,7 @@ GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}"; GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}"; ` -const testBlockStatementRole = ` +const testPostgresBlockStatementRole = ` DO $$ BEGIN IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN @@ -380,7 +380,7 @@ ALTER ROLE "{{name}}" SET search_path = foo; GRANT CONNECT ON DATABASE "postgres" TO "{{name}}"; ` -var testBlockStatementRoleSlice = []string{ +var testPostgresBlockStatementRoleSlice = []string{ ` DO $$ BEGIN @@ -403,7 +403,7 @@ $$ `GRANT CONNECT ON DATABASE "postgres" TO "{{name}}";`, } -const defaultRevocationSQL = ` +const defaultPostgresRevocationSQL = ` REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM "{{name}}"; REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM "{{name}}"; REVOKE USAGE ON SCHEMA public FROM "{{name}}"; diff --git a/helper/pluginutil/tls.go b/helper/pluginutil/tls.go index 88d88689d1ae..08f24985d225 100644 --- a/helper/pluginutil/tls.go +++ b/helper/pluginutil/tls.go @@ -21,7 +21,6 @@ import ( "github.com/hashicorp/errwrap" uuid "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/api" - "github.com/hashicorp/vault/logical" ) var ( @@ -30,6 +29,10 @@ var ( PluginUnwrapTokenEnv = "VAULT_UNWRAP_TOKEN" ) +type Wrapper interface { + ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (string, error) +} + // GenerateCACert returns a CA cert used to later sign the certificates for the // plugin client and server. func GenerateCACert() ([]byte, *x509.Certificate, *ecdsa.PrivateKey, error) { @@ -147,7 +150,7 @@ func CreateClientTLSConfig(CACert *x509.Certificate, CAKey *ecdsa.PrivateKey) (* // WrapServerConfig is used to create a server certificate and private key, then // wrap them in an unwrap token for later retrieval by the plugin. -func WrapServerConfig(sys logical.SystemView, CACertBytes []byte, CACert *x509.Certificate, CAKey *ecdsa.PrivateKey) (string, error) { +func WrapServerConfig(sys Wrapper, CACertBytes []byte, CACert *x509.Certificate, CAKey *ecdsa.PrivateKey) (string, error) { serverCertBytes, _, serverKey, err := generateSignedCert(CACert, CAKey) if err != nil { return "", err diff --git a/vault/testing.go b/vault/testing.go index b567fe75e526..7b914bbdbbee 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -231,6 +231,18 @@ func TestCoreUnsealedBackend(t testing.TB, backend physical.Backend) (*Core, [][ return core, keys, token } +func TestCoreUnsealedWithListener(t testing.TB) (*Core, [][]byte, string, net.Listener) { + core, keys, token := TestCoreUnsealed(t) + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %s", err) + } + addr := "http://" + ln.Addr().String() + core.redirectAddr = addr + + return core, keys, token, ln +} + func testTokenStore(t testing.TB, c *Core) *TokenStore { me := &MountEntry{ Table: credentialTableType, @@ -293,6 +305,10 @@ func TestKeyCopy(key []byte) []byte { return result } +func TestDynamicSystemView(c *Core) *dynamicSystemView { + return &dynamicSystemView{c, nil} +} + var testLogicalBackends = map[string]logical.Factory{} // Starts the test server which responds to SSH authentication. From ca026c6cfd37cb28390be5120f4865e452cd60b5 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 27 Mar 2017 11:46:20 -0700 Subject: [PATCH 040/152] Remove the unused sync.Once object --- builtin/logical/database/dbs/plugin_test.go | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/builtin/logical/database/dbs/plugin_test.go b/builtin/logical/database/dbs/plugin_test.go index 74e103c4ae6e..151d0c88ffa1 100644 --- a/builtin/logical/database/dbs/plugin_test.go +++ b/builtin/logical/database/dbs/plugin_test.go @@ -9,7 +9,6 @@ import ( "net" "os" "os/exec" - "sync" "testing" "time" @@ -19,10 +18,6 @@ import ( "github.com/hashicorp/vault/vault" ) -var ( - testPluginImagePull sync.Once -) - type mockPlugin struct { users map[string][]string CredentialsProducer @@ -119,6 +114,8 @@ func getCore(t *testing.T) (*vault.Core, net.Listener, logical.SystemView) { return core, ln, sys } +// This is not an actual test case, it's a helper function that will be executed +// by the go-plugin client via an exec call. func TestPlugin_Main(t *testing.T) { if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" { return From b2c4555c1fa9419a6dcee465f0bf21b410240f81 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 27 Mar 2017 15:17:28 -0700 Subject: [PATCH 041/152] Wrap the database calls with tracing information --- ...icsmiddleware.go => databasemiddleware.go} | 104 ++++++++++++++++++ builtin/logical/database/dbs/db.go | 21 +++- builtin/logical/database/dbs/plugin.go | 7 +- builtin/logical/database/dbs/postgresql.go | 6 - .../database/path_config_connection.go | 4 +- 5 files changed, 130 insertions(+), 12 deletions(-) rename builtin/logical/database/dbs/{metricsmiddleware.go => databasemiddleware.go} (60%) diff --git a/builtin/logical/database/dbs/metricsmiddleware.go b/builtin/logical/database/dbs/databasemiddleware.go similarity index 60% rename from builtin/logical/database/dbs/metricsmiddleware.go rename to builtin/logical/database/dbs/databasemiddleware.go index 61b4bd4ebca9..d3f037ecbca3 100644 --- a/builtin/logical/database/dbs/metricsmiddleware.go +++ b/builtin/logical/database/dbs/databasemiddleware.go @@ -4,8 +4,112 @@ import ( "time" metrics "github.com/armon/go-metrics" + log "github.com/mgutz/logxi/v1" ) +// ---- Tracing Middleware Domain ---- + +type databaseTracingMiddleware struct { + next DatabaseType + logger log.Logger + + typeStr string +} + +func (mw *databaseTracingMiddleware) Type() string { + return mw.next.Type() +} + +func (mw *databaseTracingMiddleware) CreateUser(statements Statements, username, password, expiration string) (err error) { + if mw.logger.IsTrace() { + defer func(then time.Time) { + mw.logger.Trace("database/CreateUser: finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) + }(time.Now()) + + mw.logger.Trace("database/CreateUser: starting", "type", mw.typeStr) + } + return mw.next.CreateUser(statements, username, password, expiration) +} + +func (mw *databaseTracingMiddleware) RenewUser(statements Statements, username, expiration string) (err error) { + if mw.logger.IsTrace() { + defer func(then time.Time) { + mw.logger.Trace("database/RenewUser: finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) + }(time.Now()) + + mw.logger.Trace("database/RenewUser: starting", "type", mw.typeStr) + } + return mw.next.RenewUser(statements, username, expiration) +} + +func (mw *databaseTracingMiddleware) RevokeUser(statements Statements, username string) (err error) { + if mw.logger.IsTrace() { + defer func(then time.Time) { + mw.logger.Trace("database/RevokeUser: finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) + }(time.Now()) + + mw.logger.Trace("database/RevokeUser: starting", "type", mw.typeStr) + } + return mw.next.RevokeUser(statements, username) +} + +func (mw *databaseTracingMiddleware) Initialize(conf map[string]interface{}) (err error) { + if mw.logger.IsTrace() { + defer func(then time.Time) { + mw.logger.Trace("database/Initialize: finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) + }(time.Now()) + + mw.logger.Trace("database/Initialize: starting", "type", mw.typeStr) + } + return mw.next.Initialize(conf) +} + +func (mw *databaseTracingMiddleware) Close() (err error) { + if mw.logger.IsTrace() { + defer func(then time.Time) { + mw.logger.Trace("database/Close: finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) + }(time.Now()) + + mw.logger.Trace("database/Close: starting", "type", mw.typeStr) + } + return mw.next.Close() +} + +func (mw *databaseTracingMiddleware) GenerateUsername(displayName string) (_ string, err error) { + if mw.logger.IsTrace() { + defer func(then time.Time) { + mw.logger.Trace("database/GenerateUsername: finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) + }(time.Now()) + + mw.logger.Trace("database/GenerateUsername: starting", "type", mw.typeStr) + } + return mw.next.GenerateUsername(displayName) +} + +func (mw *databaseTracingMiddleware) GeneratePassword() (_ string, err error) { + if mw.logger.IsTrace() { + defer func(then time.Time) { + mw.logger.Trace("database/GeneratePassword: finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) + }(time.Now()) + + mw.logger.Trace("database/GeneratePassword: starting", "type", mw.typeStr) + } + return mw.next.GeneratePassword() +} + +func (mw *databaseTracingMiddleware) GenerateExpiration(duration time.Duration) (_ string, err error) { + if mw.logger.IsTrace() { + defer func(then time.Time) { + mw.logger.Trace("database/GenerateExpiration: finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) + }(time.Now()) + + mw.logger.Trace("database/GenerateExpiration: starting", "type", mw.typeStr) + } + return mw.next.GenerateExpiration(duration) +} + +// ---- Metrics Middleware Domain ---- + type databaseMetricsMiddleware struct { next DatabaseType diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go index 4554963ac7f4..54581e465d0d 100644 --- a/builtin/logical/database/dbs/db.go +++ b/builtin/logical/database/dbs/db.go @@ -7,6 +7,7 @@ import ( "time" "github.com/hashicorp/vault/logical" + log "github.com/mgutz/logxi/v1" ) const ( @@ -21,9 +22,9 @@ var ( ErrEmptyCreationStatement = errors.New("Empty creation statements") ) -type Factory func(*DatabaseConfig, logical.SystemView) (DatabaseType, error) +type Factory func(*DatabaseConfig, logical.SystemView, log.Logger) (DatabaseType, error) -func BuiltinFactory(conf *DatabaseConfig, sys logical.SystemView) (DatabaseType, error) { +func BuiltinFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Logger) (DatabaseType, error) { var dbType DatabaseType switch conf.DatabaseType { @@ -76,10 +77,17 @@ func BuiltinFactory(conf *DatabaseConfig, sys logical.SystemView) (DatabaseType, typeStr: dbType.Type(), } + // Wrap with tracing middleware + dbType = &databaseTracingMiddleware{ + next: dbType, + typeStr: dbType.Type(), + logger: logger, + } + return dbType, nil } -func PluginFactory(conf *DatabaseConfig, sys logical.SystemView) (DatabaseType, error) { +func PluginFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Logger) (DatabaseType, error) { if conf.PluginCommand == "" { return nil, errors.New("ERROR") } @@ -99,6 +107,13 @@ func PluginFactory(conf *DatabaseConfig, sys logical.SystemView) (DatabaseType, typeStr: db.Type(), } + // Wrap with tracing middleware + db = &databaseTracingMiddleware{ + next: db, + typeStr: db.Type(), + logger: logger, + } + return db, nil } diff --git a/builtin/logical/database/dbs/plugin.go b/builtin/logical/database/dbs/plugin.go index 1213a3677063..b1f9abe20cea 100644 --- a/builtin/logical/database/dbs/plugin.go +++ b/builtin/logical/database/dbs/plugin.go @@ -82,7 +82,12 @@ func newPluginClient(sys pluginutil.Wrapper, command, checksum string) (Database // Add the response wrap token to the ENV of the plugin commandArr := strings.Split(command, " ") - cmd := exec.Command(commandArr[0], commandArr[1]) + var cmd *exec.Cmd + if len(commandArr) > 1 { + cmd = exec.Command(commandArr[0], commandArr[1]) + } else { + cmd = exec.Command(commandArr[0]) + } cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", pluginutil.PluginUnwrapTokenEnv, wrapToken)) checksumDecoded, err := hex.DecodeString(checksum) diff --git a/builtin/logical/database/dbs/postgresql.go b/builtin/logical/database/dbs/postgresql.go index 20d548f9204f..c8ba110cf7a8 100644 --- a/builtin/logical/database/dbs/postgresql.go +++ b/builtin/logical/database/dbs/postgresql.go @@ -43,13 +43,11 @@ func (p *PostgreSQL) CreateUser(statements Statements, username, password, expir } // Start a transaction - // b.logger.Trace("postgres/pathRoleCreateRead: starting transaction") tx, err := db.Begin() if err != nil { return err } defer func() { - // b.logger.Trace("postgres/pathRoleCreateRead: rolling back transaction") tx.Rollback() }() // Return the secret @@ -61,7 +59,6 @@ func (p *PostgreSQL) CreateUser(statements Statements, username, password, expir continue } - // b.logger.Trace("postgres/pathRoleCreateRead: preparing statement") stmt, err := tx.Prepare(queryHelper(query, map[string]string{ "name": username, "password": password, @@ -71,15 +68,12 @@ func (p *PostgreSQL) CreateUser(statements Statements, username, password, expir return err } defer stmt.Close() - // b.logger.Trace("postgres/pathRoleCreateRead: executing statement") if _, err := stmt.Exec(); err != nil { return err } } // Commit the transaction - - // b.logger.Trace("postgres/pathRoleCreateRead: committing transaction") if err := tx.Commit(); err != nil { return err } diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index ba6a3780509a..1b0878670e6e 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -63,7 +63,7 @@ func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framew factory := config.GetFactory() - db, err = factory(&config, b.System()) + db, err = factory(&config, b.System(), b.logger) if err != nil { return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil } @@ -262,7 +262,7 @@ func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework. b.Lock() defer b.Unlock() - db, err := factory(config, b.System()) + db, err := factory(config, b.System(), b.logger) if err != nil { return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil } From d93378bb29836328d49738d1204a1476a7b8791a Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 28 Mar 2017 10:04:42 -0700 Subject: [PATCH 042/152] Fix for checking types of database on update --- builtin/logical/database/path_config_connection.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 1b0878670e6e..ff633e74565e 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -279,6 +279,8 @@ func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework. } if _, ok := b.connections[name]; ok { + newType := db.Type() + // Don't update connection until the reset api is hit, close for // now. err = db.Close() @@ -287,7 +289,7 @@ func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework. } // Don't allow the connection type to change - if b.connections[name].Type() != connType { + if b.connections[name].Type() != newType { return logical.ErrorResponse("Can not change type of existing connection."), nil } } else { From 6de5cfad5e5796c160833a27f11d809fae2cf96c Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 28 Mar 2017 11:30:45 -0700 Subject: [PATCH 043/152] Add functionaility to build db objects from disk so restarts work --- builtin/logical/database/backend.go | 46 +++++++++++++++++-- .../database/dbs/connectionproducer.go | 8 ++-- builtin/logical/database/dbs/db.go | 5 +- .../database/path_config_connection.go | 37 +++------------ builtin/logical/database/path_role_create.go | 24 ++++------ builtin/logical/database/path_roles.go | 7 +++ builtin/logical/database/secret_creds.go | 28 +++++------ helper/pluginutil/tls.go | 3 ++ 8 files changed, 88 insertions(+), 70 deletions(-) diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index 6108652532d8..f8bcc60f1d0a 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -1,6 +1,7 @@ package database import ( + "fmt" "strings" "sync" @@ -52,14 +53,11 @@ type databaseBackend struct { logger log.Logger *framework.Backend - sync.RWMutex + sync.Mutex } // resetAllDBs closes all connections from all database types func (b *databaseBackend) closeAllDBs() { - b.logger.Trace("postgres/resetdb: enter") - defer b.logger.Trace("postgres/resetdb: exit") - b.Lock() defer b.Unlock() @@ -68,6 +66,46 @@ func (b *databaseBackend) closeAllDBs() { } } +// This function is used to retrieve a database object either from the cached +// connection map or by using the database config in storage. The caller of this +// function needs to hold the backend's lock. +func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (dbs.DatabaseType, error) { + // if the object already is built and cached, return it + db, ok := b.connections[name] + if ok { + return db, nil + } + + entry, err := s.Get(fmt.Sprintf("dbs/%s", name)) + if err != nil { + return nil, fmt.Errorf("failed to read connection configuration with name: %s", name) + } + if entry == nil { + return nil, fmt.Errorf("failed to find entry for connection with name: %s", name) + } + + var config dbs.DatabaseConfig + if err := entry.DecodeJSON(&config); err != nil { + return nil, err + } + + factory := config.GetFactory() + + db, err = factory(&config, b.System(), b.logger) + if err != nil { + return nil, err + } + + err = db.Initialize(config.ConnectionDetails) + if err != nil { + return nil, err + } + + b.connections[name] = db + + return db, nil +} + func (b *databaseBackend) Role(s logical.Storage, n string) (*roleEntry, error) { entry, err := s.Get("role/" + n) if err != nil { diff --git a/builtin/logical/database/dbs/connectionproducer.go b/builtin/logical/database/dbs/connectionproducer.go index 1e944c7b964b..dae8d9400e23 100644 --- a/builtin/logical/database/dbs/connectionproducer.go +++ b/builtin/logical/database/dbs/connectionproducer.go @@ -20,7 +20,7 @@ import ( ) var ( - errNotInitalized = errors.New("Connection has not been initalized") + errNotInitalized = errors.New("connection has not been initalized") ) type ConnectionProducer interface { @@ -142,7 +142,7 @@ func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}) er c.initalized = true if _, err := c.connection(); err != nil { - return fmt.Errorf("Error Initalizing Connection: %s", err) + return fmt.Errorf("error Initalizing Connection: %s", err) } return nil @@ -244,7 +244,7 @@ func (c *cassandraConnectionProducer) createSession() (*gocql.Session, error) { session, err := clusterConfig.CreateSession() if err != nil { - return nil, fmt.Errorf("Error creating session: %s", err) + return nil, fmt.Errorf("error creating session: %s", err) } // Set consistency @@ -260,7 +260,7 @@ func (c *cassandraConnectionProducer) createSession() (*gocql.Session, error) { // Verify the info err = session.Query(`LIST USERS`).Exec() if err != nil { - return nil, fmt.Errorf("Error validating connection info: %s", err) + return nil, fmt.Errorf("error validating connection info: %s", err) } return session, nil diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go index 54581e465d0d..74f5a2605718 100644 --- a/builtin/logical/database/dbs/db.go +++ b/builtin/logical/database/dbs/db.go @@ -18,10 +18,11 @@ const ( ) var ( - ErrUnsupportedDatabaseType = errors.New("Unsupported database type") - ErrEmptyCreationStatement = errors.New("Empty creation statements") + ErrUnsupportedDatabaseType = errors.New("unsupported database type") + ErrEmptyCreationStatement = errors.New("empty creation statements") ) +// Factory function for type Factory func(*DatabaseConfig, logical.SystemView, log.Logger) (DatabaseType, error) func BuiltinFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Logger) (DatabaseType, error) { diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index ff633e74565e..b4c699750d55 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -1,7 +1,6 @@ package database import ( - "errors" "fmt" "strings" "time" @@ -34,47 +33,24 @@ func pathResetConnection(b *databaseBackend) *framework.Path { func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { name := data.Get("name").(string) if name == "" { - return nil, errors.New("No database name set") + return logical.ErrorResponse("Empty name attribute given"), nil } // Grab the mutex lock b.Lock() defer b.Unlock() - entry, err := req.Storage.Get(fmt.Sprintf("dbs/%s", name)) - if err != nil { - return nil, fmt.Errorf("failed to read connection configuration") - } - if entry == nil { - return nil, nil - } - - var config dbs.DatabaseConfig - if err := entry.DecodeJSON(&config); err != nil { - return nil, err - } - db, ok := b.connections[name] - if !ok { - return logical.ErrorResponse("Can not change type of existing connection."), nil - } - - db.Close() - - factory := config.GetFactory() - - db, err = factory(&config, b.System(), b.logger) - if err != nil { - return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil + if ok { + db.Close() + delete(b.connections, name) } - err = db.Initialize(config.ConnectionDetails) + db, err := b.getOrCreateDBObj(req.Storage, name) if err != nil { - return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil + return nil, err } - b.connections[name] = db - return nil, nil } @@ -306,7 +282,6 @@ func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework. return nil, err } - // Reset the DB connection resp := &logical.Response{} resp.AddWarning("Read access to this endpoint should be controlled via ACLs as it will return the connection string or URL as it is, including passwords, if any.") diff --git a/builtin/logical/database/path_role_create.go b/builtin/logical/database/path_role_create.go index 14b65cbb3106..d379ef26739b 100644 --- a/builtin/logical/database/path_role_create.go +++ b/builtin/logical/database/path_role_create.go @@ -27,34 +27,28 @@ func pathRoleCreate(b *databaseBackend) *framework.Path { } func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - b.logger.Trace("postgres/pathRoleCreateRead: enter") - defer b.logger.Trace("postgres/pathRoleCreateRead: exit") - name := data.Get("name").(string) // Get the role - b.logger.Trace("postgres/pathRoleCreateRead: getting role") role, err := b.Role(req.Storage, name) if err != nil { return nil, err } if role == nil { - return logical.ErrorResponse(fmt.Sprintf("unknown role: %s", name)), nil + return logical.ErrorResponse(fmt.Sprintf("Unknown role: %s", name)), nil } - // Generate the username, password and expiration - - // Get our handle - b.logger.Trace("postgres/pathRoleCreateRead: getting database handle") + b.Lock() + defer b.Unlock() - b.RLock() - defer b.RUnlock() - db, ok := b.connections[role.DBName] - if !ok { + // Get the Database object + db, err := b.getOrCreateDBObj(req.Storage, role.DBName) + if err != nil { // TODO: return a resp error instead? - return nil, fmt.Errorf("Cound not find DB with name: %s", role.DBName) + return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) } + // Generate the username, password and expiration username, err := db.GenerateUsername(req.DisplayName) if err != nil { return nil, err @@ -70,12 +64,12 @@ func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framewo return nil, err } + // Create the user err = db.CreateUser(role.Statements, username, password, expiration) if err != nil { return nil, err } - b.logger.Trace("postgres/pathRoleCreateRead: generating secret") resp := b.Secret(SecretCredsType).Response(map[string]interface{}{ "username": username, "password": password, diff --git a/builtin/logical/database/path_roles.go b/builtin/logical/database/path_roles.go index 9a5bb9324dfd..6f62c79d98a7 100644 --- a/builtin/logical/database/path_roles.go +++ b/builtin/logical/database/path_roles.go @@ -126,7 +126,14 @@ func (b *databaseBackend) pathRoleList(req *logical.Request, d *framework.FieldD func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { name := data.Get("name").(string) + if name == "" { + return logical.ErrorResponse("Empty role name attribute given"), nil + } + dbName := data.Get("db_name").(string) + if dbName == "" { + return logical.ErrorResponse("Empty database name attribute given"), nil + } // Get statements creationStmts := data.Get("creation_statements").(string) diff --git a/builtin/logical/database/secret_creds.go b/builtin/logical/database/secret_creds.go index e39525a18c42..2b63ea1f89b0 100644 --- a/builtin/logical/database/secret_creds.go +++ b/builtin/logical/database/secret_creds.go @@ -29,7 +29,7 @@ func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.Fi roleNameRaw, ok := req.Secret.InternalData["role"] if !ok { - return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"]) + return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"]) } role, err := b.Role(req.Storage, roleNameRaw.(string)) @@ -37,7 +37,7 @@ func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.Fi return nil, err } if role == nil { - return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"]) + return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"]) } f := framework.LeaseExtend(role.DefaultTTL, role.MaxTTL, b.System()) @@ -47,13 +47,13 @@ func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.Fi } // Grab the read lock - b.RLock() - defer b.RUnlock() + b.Lock() + defer b.Unlock() // Get our connection - db, ok := b.connections[role.DBName] - if !ok { - return nil, fmt.Errorf("Could not find connection with name %s", role.DBName) + db, err := b.getOrCreateDBObj(req.Storage, role.DBName) + if err != nil { + return nil, fmt.Errorf("could not find connection with name %s, got err: %s", role.DBName, err) } // Make sure we increase the VALID UNTIL endpoint for this user. @@ -81,7 +81,7 @@ func (b *databaseBackend) secretCredsRevoke(req *logical.Request, d *framework.F roleNameRaw, ok := req.Secret.InternalData["role"] if !ok { - return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"]) + return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"]) } role, err := b.Role(req.Storage, roleNameRaw.(string)) @@ -89,7 +89,7 @@ func (b *databaseBackend) secretCredsRevoke(req *logical.Request, d *framework.F return nil, err } if role == nil { - return nil, fmt.Errorf("Could not find role with name: %s", req.Secret.InternalData["role"]) + return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"]) } /* TODO: think about how to handle this case. @@ -109,13 +109,13 @@ func (b *databaseBackend) secretCredsRevoke(req *logical.Request, d *framework.F }*/ // Grab the read lock - b.RLock() - defer b.RUnlock() + b.Lock() + defer b.Unlock() // Get our connection - db, ok := b.connections[role.DBName] - if !ok { - return nil, fmt.Errorf("Could not find database with name: %s", role.DBName) + db, err := b.getOrCreateDBObj(req.Storage, role.DBName) + if err != nil { + return nil, fmt.Errorf("could not find database with name: %s, got error: %s", role.DBName, err) } err = db.RevokeUser(role.Statements, username) diff --git a/helper/pluginutil/tls.go b/helper/pluginutil/tls.go index 08f24985d225..63ae2932f172 100644 --- a/helper/pluginutil/tls.go +++ b/helper/pluginutil/tls.go @@ -217,6 +217,9 @@ func VaultPluginTLSProvider() (*tls.Config, error) { if err != nil { return nil, errwrap.Wrapf("error during token unwrap request: {{err}}", err) } + if secret == nil { + return nil, errors.New("error during token unwrap request secret is nil") + } // Retrieve and parse the CA Certificate CABytesRaw, ok := secret.Data["CACert"].(string) From 0c562fa3d767dd6a8992c30e21e7d7411a581855 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 28 Mar 2017 12:20:17 -0700 Subject: [PATCH 044/152] Update tests --- builtin/logical/database/dbs/mysql_test.go | 14 ++++++++------ builtin/logical/database/dbs/plugin_test.go | 9 +++++---- builtin/logical/database/dbs/postgresql_test.go | 14 ++++++++------ 3 files changed, 21 insertions(+), 16 deletions(-) diff --git a/builtin/logical/database/dbs/mysql_test.go b/builtin/logical/database/dbs/mysql_test.go index f4d1247023c1..553acc8ffd52 100644 --- a/builtin/logical/database/dbs/mysql_test.go +++ b/builtin/logical/database/dbs/mysql_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + log "github.com/mgutz/logxi/v1" dockertest "gopkg.in/ory-am/dockertest.v2" ) @@ -70,21 +71,22 @@ func TestMySQL_Initialize(t *testing.T) { }, } - dbRaw, err := BuiltinFactory(conf, nil) + dbRaw, err := BuiltinFactory(conf, nil, &log.NullLogger{}) if err != nil { t.Fatalf("err: %s", err) } // Deconsturct the middleware chain to get the underlying mysql object - dbMetrics := dbRaw.(*databaseMetricsMiddleware) + dbTracer := dbRaw.(*databaseTracingMiddleware) + dbMetrics := dbTracer.next.(*databaseMetricsMiddleware) db := dbMetrics.next.(*MySQL) + connProducer := db.ConnectionProducer.(*sqlConnectionProducer) err = dbRaw.Initialize(conf.ConnectionDetails) if err != nil { t.Fatalf("err: %s", err) } - connProducer := db.ConnectionProducer.(*sqlConnectionProducer) if !connProducer.initalized { t.Fatal("Database should be initalized") } @@ -112,7 +114,7 @@ func TestMySQL_CreateUser(t *testing.T) { }, } - db, err := BuiltinFactory(conf, nil) + db, err := BuiltinFactory(conf, nil, &log.NullLogger{}) if err != nil { t.Fatalf("err: %s", err) } @@ -200,7 +202,7 @@ func TestMySQL_RenewUser(t *testing.T) { }, } - db, err := BuiltinFactory(conf, nil) + db, err := BuiltinFactory(conf, nil, &log.NullLogger{}) if err != nil { t.Fatalf("err: %s", err) } @@ -256,7 +258,7 @@ func TestMySQL_RevokeUser(t *testing.T) { }, } - db, err := BuiltinFactory(conf, nil) + db, err := BuiltinFactory(conf, nil, &log.NullLogger{}) if err != nil { t.Fatalf("err: %s", err) } diff --git a/builtin/logical/database/dbs/plugin_test.go b/builtin/logical/database/dbs/plugin_test.go index 151d0c88ffa1..60cb6814dd5f 100644 --- a/builtin/logical/database/dbs/plugin_test.go +++ b/builtin/logical/database/dbs/plugin_test.go @@ -16,6 +16,7 @@ import ( "github.com/hashicorp/vault/http" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/vault" + log "github.com/mgutz/logxi/v1" ) type mockPlugin struct { @@ -134,7 +135,7 @@ func TestPlugin_Initialize(t *testing.T) { defer ln.Close() conf := getConf(t) - dbRaw, err := PluginFactory(conf, sys) + dbRaw, err := PluginFactory(conf, sys, &log.NullLogger{}) if err != nil { t.Fatalf("err: %s", err) } @@ -155,7 +156,7 @@ func TestPlugin_CreateUser(t *testing.T) { defer ln.Close() conf := getConf(t) - db, err := PluginFactory(conf, sys) + db, err := PluginFactory(conf, sys, &log.NullLogger{}) if err != nil { t.Fatalf("err: %s", err) } @@ -209,7 +210,7 @@ func TestPlugin_RenewUser(t *testing.T) { defer ln.Close() conf := getConf(t) - db, err := PluginFactory(conf, sys) + db, err := PluginFactory(conf, sys, &log.NullLogger{}) if err != nil { t.Fatalf("err: %s", err) } @@ -255,7 +256,7 @@ func TestPlugin_RevokeUser(t *testing.T) { defer ln.Close() conf := getConf(t) - db, err := PluginFactory(conf, sys) + db, err := PluginFactory(conf, sys, &log.NullLogger{}) if err != nil { t.Fatalf("err: %s", err) } diff --git a/builtin/logical/database/dbs/postgresql_test.go b/builtin/logical/database/dbs/postgresql_test.go index dab720920b8b..83aed50ba91b 100644 --- a/builtin/logical/database/dbs/postgresql_test.go +++ b/builtin/logical/database/dbs/postgresql_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + log "github.com/mgutz/logxi/v1" dockertest "gopkg.in/ory-am/dockertest.v2" ) @@ -77,21 +78,22 @@ func TestPostgreSQL_Initialize(t *testing.T) { }, } - dbRaw, err := BuiltinFactory(conf, nil) + dbRaw, err := BuiltinFactory(conf, nil, &log.NullLogger{}) if err != nil { t.Fatalf("err: %s", err) } // Deconsturct the middleware chain to get the underlying postgres object - dbMetrics := dbRaw.(*databaseMetricsMiddleware) + dbTracer := dbRaw.(*databaseTracingMiddleware) + dbMetrics := dbTracer.next.(*databaseMetricsMiddleware) db := dbMetrics.next.(*PostgreSQL) + connProducer := db.ConnectionProducer.(*sqlConnectionProducer) err = dbRaw.Initialize(conf.ConnectionDetails) if err != nil { t.Fatalf("err: %s", err) } - connProducer := db.ConnectionProducer.(*sqlConnectionProducer) if !connProducer.initalized { t.Fatal("Database should be initalized") } @@ -119,7 +121,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) { }, } - db, err := BuiltinFactory(conf, nil) + db, err := BuiltinFactory(conf, nil, &log.NullLogger{}) if err != nil { t.Fatalf("err: %s", err) } @@ -212,7 +214,7 @@ func TestPostgreSQL_RenewUser(t *testing.T) { }, } - db, err := BuiltinFactory(conf, nil) + db, err := BuiltinFactory(conf, nil, &log.NullLogger{}) if err != nil { t.Fatalf("err: %s", err) } @@ -268,7 +270,7 @@ func TestPostgreSQL_RevokeUser(t *testing.T) { }, } - db, err := BuiltinFactory(conf, nil) + db, err := BuiltinFactory(conf, nil, &log.NullLogger{}) if err != nil { t.Fatalf("err: %s", err) } From 947fd66480ff4dd878def978eae9c3d22c1ed3f7 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 28 Mar 2017 12:57:30 -0700 Subject: [PATCH 045/152] Cleanup the db factory code and add comments --- builtin/logical/database/dbs/db.go | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go index 74f5a2605718..2637a73d1055 100644 --- a/builtin/logical/database/dbs/db.go +++ b/builtin/logical/database/dbs/db.go @@ -20,11 +20,15 @@ const ( var ( ErrUnsupportedDatabaseType = errors.New("unsupported database type") ErrEmptyCreationStatement = errors.New("empty creation statements") + ErrEmptyPluginCommand = errors.New("empty plugin command") + ErrEmptyPluginChecksum = errors.New("empty plugin checksum") ) -// Factory function for +// Factory function definition type Factory func(*DatabaseConfig, logical.SystemView, log.Logger) (DatabaseType, error) +// BuiltinFactory is used to build builtin database types. It wraps the database +// object in a logging and metrics middleware. func BuiltinFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Logger) (DatabaseType, error) { var dbType DatabaseType @@ -88,15 +92,20 @@ func BuiltinFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Log return dbType, nil } +// PluginFactory is used to build plugin database types. It wraps the database +// object in a logging and metrics middleware. func PluginFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Logger) (DatabaseType, error) { if conf.PluginCommand == "" { - return nil, errors.New("ERROR") + return nil, ErrEmptyPluginCommand } if conf.PluginChecksum == "" { - return nil, errors.New("ERROR") + return nil, ErrEmptyPluginChecksum } + // Make sure the database type is set to plugin + conf.DatabaseType = pluginTypeName + db, err := newPluginClient(sys, conf.PluginCommand, conf.PluginChecksum) if err != nil { return nil, err @@ -118,6 +127,7 @@ func PluginFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Logg return db, nil } +// DatabaseType is the interface that all database objects must implement. type DatabaseType interface { Type() string CreateUser(statements Statements, username, password, expiration string) error @@ -129,8 +139,12 @@ type DatabaseType interface { CredentialsProducer } +// DatabaseConfig is used by the Factory function to configure a DatabaseType +// object. type DatabaseConfig struct { - DatabaseType string `json:"type" structs:"type" mapstructure:"type"` + DatabaseType string `json:"type" structs:"type" mapstructure:"type"` + // ConnectionDetails stores the database specific connection settings needed + // by each database type. ConnectionDetails map[string]interface{} `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"` MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"` MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"` @@ -139,6 +153,8 @@ type DatabaseConfig struct { PluginChecksum string `json:"plugin_checksum" structs:"plugin_checksum" mapstructure:"plugin_checksum"` } +// GetFactory returns the appropriate factory method for the given database +// type. func (dc *DatabaseConfig) GetFactory() Factory { if dc.DatabaseType == pluginTypeName { return PluginFactory From 8ef78f06107c43024516e94f0c6c0272a3834c1c Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 28 Mar 2017 13:08:11 -0700 Subject: [PATCH 046/152] Add comments to connection and credential producers --- builtin/logical/database/dbs/connectionproducer.go | 7 ++++++- builtin/logical/database/dbs/credentialsproducer.go | 8 ++++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/builtin/logical/database/dbs/connectionproducer.go b/builtin/logical/database/dbs/connectionproducer.go index dae8d9400e23..ca9e7250e268 100644 --- a/builtin/logical/database/dbs/connectionproducer.go +++ b/builtin/logical/database/dbs/connectionproducer.go @@ -23,6 +23,9 @@ var ( errNotInitalized = errors.New("connection has not been initalized") ) +// ConnectionProducer can be used as an embeded interface in the DatabaseType +// definition. It implements the methods dealing with individual database +// connections and is used in all the builtin database types. type ConnectionProducer interface { Close() error Initialize(map[string]interface{}) error @@ -31,7 +34,7 @@ type ConnectionProducer interface { connection() (interface{}, error) } -// sqlConnectionProducer impliments ConnectionProducer and provides a generic producer for most sql databases +// sqlConnectionProducer implements ConnectionProducer and provides a generic producer for most sql databases type sqlConnectionProducer struct { ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"` @@ -111,6 +114,8 @@ func (c *sqlConnectionProducer) Close() error { return nil } +// cassandraConnectionProducer implements ConnectionProducer and provides an +// interface for cassandra databases to make connections. type cassandraConnectionProducer struct { Hosts string `json:"hosts" structs:"hosts" mapstructure:"hosts"` Username string `json:"username" structs:"username" mapstructure:"username"` diff --git a/builtin/logical/database/dbs/credentialsproducer.go b/builtin/logical/database/dbs/credentialsproducer.go index 5ae3b128e6c2..6bd543f4e175 100644 --- a/builtin/logical/database/dbs/credentialsproducer.go +++ b/builtin/logical/database/dbs/credentialsproducer.go @@ -8,20 +8,22 @@ import ( uuid "github.com/hashicorp/go-uuid" ) +// CredentialsProducer can be used as an embeded interface in the DatabaseType +// definition. It implements the methods for generating user information for a +// particular database type and is used in all the builtin database types. type CredentialsProducer interface { GenerateUsername(displayName string) (string, error) GeneratePassword() (string, error) GenerateExpiration(ttl time.Duration) (string, error) } -// sqlCredentialsProducer impliments CredentialsProducer and provides a generic credentials producer for most sql database types. +// sqlCredentialsProducer implements CredentialsProducer and provides a generic credentials producer for most sql database types. type sqlCredentialsProducer struct { displayNameLen int usernameLen int } func (scp *sqlCredentialsProducer) GenerateUsername(displayName string) (string, error) { - // Generate the username, password and expiration. PG limits user to 63 characters if scp.displayNameLen > 0 && len(displayName) > scp.displayNameLen { displayName = displayName[:scp.displayNameLen] } @@ -52,6 +54,8 @@ func (scp *sqlCredentialsProducer) GenerateExpiration(ttl time.Duration) (string Format("2006-01-02 15:04:05-0700"), nil } +// cassandraCredentialsProducer implements CredentialsProducer and provides an +// interface for cassandra databases to generate user information. type cassandraCredentialsProducer struct{} func (ccp *cassandraCredentialsProducer) GenerateUsername(displayName string) (string, error) { From 1d3d3b780350c9265b6f5a8f5bbce35713855112 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 28 Mar 2017 14:37:57 -0700 Subject: [PATCH 047/152] fix for plugin commands that have more than one paramater --- builtin/logical/database/dbs/plugin.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/builtin/logical/database/dbs/plugin.go b/builtin/logical/database/dbs/plugin.go index b1f9abe20cea..4bac0d16e190 100644 --- a/builtin/logical/database/dbs/plugin.go +++ b/builtin/logical/database/dbs/plugin.go @@ -84,7 +84,7 @@ func newPluginClient(sys pluginutil.Wrapper, command, checksum string) (Database commandArr := strings.Split(command, " ") var cmd *exec.Cmd if len(commandArr) > 1 { - cmd = exec.Command(commandArr[0], commandArr[1]) + cmd = exec.Command(commandArr[0], commandArr[1:]...) } else { cmd = exec.Command(commandArr[0]) } From 2b08521ab6c97284c2843dd3c1a73bd10aa0fa77 Mon Sep 17 00:00:00 2001 From: Calvin Leung Huang Date: Mon, 3 Apr 2017 12:59:30 -0400 Subject: [PATCH 048/152] Database refactor mssql (#2562) * WIP on mssql secret backend refactor * Add RevokeUser test, and use sqlserver driver internally * Remove debug statements * Fix code comment --- .../database/dbs/connectionproducer.go | 9 +- builtin/logical/database/dbs/db.go | 17 +- builtin/logical/database/dbs/mssql.go | 219 +++++++++++++++++ builtin/logical/database/dbs/mssql_test.go | 221 ++++++++++++++++++ 4 files changed, 464 insertions(+), 2 deletions(-) create mode 100644 builtin/logical/database/dbs/mssql.go create mode 100644 builtin/logical/database/dbs/mssql_test.go diff --git a/builtin/logical/database/dbs/connectionproducer.go b/builtin/logical/database/dbs/connectionproducer.go index ca9e7250e268..b5dc93951777 100644 --- a/builtin/logical/database/dbs/connectionproducer.go +++ b/builtin/logical/database/dbs/connectionproducer.go @@ -10,6 +10,7 @@ import ( "time" // Import sql drivers + _ "github.com/denisenkom/go-mssqldb" _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" "github.com/mitchellh/mapstructure" @@ -73,6 +74,12 @@ func (c *sqlConnectionProducer) connection() (interface{}, error) { c.db.Close() } + // For mssql backend, switch to sqlserver instead + dbType := c.config.DatabaseType + if c.config.DatabaseType == "mssql" { + dbType = "sqlserver" + } + // Otherwise, attempt to make connection conn := c.ConnectionURL @@ -86,7 +93,7 @@ func (c *sqlConnectionProducer) connection() (interface{}, error) { } var err error - c.db, err = sql.Open(c.config.DatabaseType, conn) + c.db, err = sql.Open(dbType, conn) if err != nil { return nil, err } diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go index 2637a73d1055..cf8f8ee7fdd1 100644 --- a/builtin/logical/database/dbs/db.go +++ b/builtin/logical/database/dbs/db.go @@ -13,6 +13,7 @@ import ( const ( postgreSQLTypeName = "postgres" mySQLTypeName = "mysql" + msSQLTypeName = "mssql" cassandraTypeName = "cassandra" pluginTypeName = "plugin" ) @@ -61,6 +62,20 @@ func BuiltinFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Log CredentialsProducer: credsProducer, } + case msSQLTypeName: + connProducer := &sqlConnectionProducer{} + connProducer.config = conf + + credsProducer := &sqlCredentialsProducer{ + displayNameLen: 10, + usernameLen: 63, + } + + dbType = &MSSQL{ + ConnectionProducer: connProducer, + CredentialsProducer: credsProducer, + } + case cassandraTypeName: connProducer := &cassandraConnectionProducer{} connProducer.config = conf @@ -163,7 +178,7 @@ func (dc *DatabaseConfig) GetFactory() Factory { return BuiltinFactory } -// Statments set in role creation and passed into the database type's functions. +// Statements set in role creation and passed into the database type's functions. // TODO: Add a way of setting defaults here. type Statements struct { CreationStatements string `json:"creation_statments" mapstructure:"creation_statements" structs:"creation_statments"` diff --git a/builtin/logical/database/dbs/mssql.go b/builtin/logical/database/dbs/mssql.go new file mode 100644 index 000000000000..b7439b0a82ca --- /dev/null +++ b/builtin/logical/database/dbs/mssql.go @@ -0,0 +1,219 @@ +package dbs + +import ( + "database/sql" + "fmt" + "strings" + + "github.com/hashicorp/vault/helper/strutil" +) + +// MSSQL is an implementation of DatabaseType interface +type MSSQL struct { + ConnectionProducer + CredentialsProducer +} + +// Type returns the TypeName for this backend +func (m *MSSQL) Type() string { + return msSQLTypeName +} + +func (m *MSSQL) getConnection() (*sql.DB, error) { + db, err := m.connection() + if err != nil { + return nil, err + } + + return db.(*sql.DB), nil +} + +// CreateUser generates the username/password on the underlying MSSQL secret backend as instructed by +// the CreationStatement provided. +func (m *MSSQL) CreateUser(statements Statements, username, password, expiration string) error { + // Grab the lock + m.Lock() + defer m.Unlock() + + // Get the connection + db, err := m.getConnection() + if err != nil { + return err + } + + if statements.CreationStatements == "" { + return ErrEmptyCreationStatement + } + + // Start a transaction + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + // Execute each query + for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + stmt, err := tx.Prepare(queryHelper(query, map[string]string{ + "name": username, + "password": password, + })) + if err != nil { + return err + } + defer stmt.Close() + if _, err := stmt.Exec(); err != nil { + return err + } + } + + // Commit the transaction + if err := tx.Commit(); err != nil { + return err + } + + return nil +} + +// RenewUser is not supported on MSSQL, so this is a no-op. +func (m *MSSQL) RenewUser(statements Statements, username, expiration string) error { + // NOOP + return nil +} + +// RevokeUser attempts to drop the specified user. It will first attempt to disable login, +// then kill pending connections from that user, and finally drop the user and login from the +// database instance. +func (m *MSSQL) RevokeUser(statements Statements, username string) error { + // Get connection + db, err := m.getConnection() + if err != nil { + return err + } + + // First disable server login + disableStmt, err := db.Prepare(fmt.Sprintf("ALTER LOGIN [%s] DISABLE;", username)) + if err != nil { + return err + } + defer disableStmt.Close() + if _, err := disableStmt.Exec(); err != nil { + return err + } + + // Query for sessions for the login so that we can kill any outstanding + // sessions. There cannot be any active sessions before we drop the logins + // This isn't done in a transaction because even if we fail along the way, + // we want to remove as much access as possible + sessionStmt, err := db.Prepare(fmt.Sprintf( + "SELECT session_id FROM sys.dm_exec_sessions WHERE login_name = '%s';", username)) + if err != nil { + return err + } + defer sessionStmt.Close() + + sessionRows, err := sessionStmt.Query() + if err != nil { + return err + } + defer sessionRows.Close() + + var revokeStmts []string + for sessionRows.Next() { + var sessionID int + err = sessionRows.Scan(&sessionID) + if err != nil { + return err + } + revokeStmts = append(revokeStmts, fmt.Sprintf("KILL %d;", sessionID)) + } + + // Query for database users using undocumented stored procedure for now since + // it is the easiest way to get this information; + // we need to drop the database users before we can drop the login and the role + // This isn't done in a transaction because even if we fail along the way, + // we want to remove as much access as possible + stmt, err := db.Prepare(fmt.Sprintf("EXEC sp_msloginmappings '%s';", username)) + if err != nil { + return err + } + defer stmt.Close() + + rows, err := stmt.Query() + if err != nil { + return err + } + defer rows.Close() + + for rows.Next() { + var loginName, dbName, qUsername string + var aliasName sql.NullString + err = rows.Scan(&loginName, &dbName, &qUsername, &aliasName) + if err != nil { + return err + } + revokeStmts = append(revokeStmts, fmt.Sprintf(dropUserSQL, dbName, username, username)) + } + + // we do not stop on error, as we want to remove as + // many permissions as possible right now + var lastStmtError error + for _, query := range revokeStmts { + stmt, err := db.Prepare(query) + if err != nil { + lastStmtError = err + continue + } + defer stmt.Close() + _, err = stmt.Exec() + if err != nil { + lastStmtError = err + } + } + + // can't drop if not all database users are dropped + if rows.Err() != nil { + return fmt.Errorf("cound not generate sql statements for all rows: %s", rows.Err()) + } + if lastStmtError != nil { + return fmt.Errorf("could not perform all sql statements: %s", lastStmtError) + } + + // Drop this login + stmt, err = db.Prepare(fmt.Sprintf(dropLoginSQL, username, username)) + if err != nil { + return err + } + defer stmt.Close() + if _, err := stmt.Exec(); err != nil { + return err + } + + return nil +} + +const dropUserSQL = ` +USE [%s] +IF EXISTS + (SELECT name + FROM sys.database_principals + WHERE name = N'%s') +BEGIN + DROP USER [%s] +END +` + +const dropLoginSQL = ` +IF EXISTS + (SELECT name + FROM master.sys.server_principals + WHERE name = N'%s') +BEGIN + DROP LOGIN [%s] +END +` diff --git a/builtin/logical/database/dbs/mssql_test.go b/builtin/logical/database/dbs/mssql_test.go new file mode 100644 index 000000000000..f2169299fa60 --- /dev/null +++ b/builtin/logical/database/dbs/mssql_test.go @@ -0,0 +1,221 @@ +package dbs + +import ( + "database/sql" + "fmt" + "os" + "sync" + "testing" + "time" + + _ "github.com/denisenkom/go-mssqldb" + log "github.com/mgutz/logxi/v1" + dockertest "gopkg.in/ory-am/dockertest.v3" +) + +var ( + testMSQLImagePull sync.Once +) + +func prepareMSSQLTestContainer(t *testing.T) (cleanup func(), retURL string) { + if os.Getenv("MSSQL_URL") != "" { + return func() {}, os.Getenv("MSSQL_URL") + } + + pool, err := dockertest.NewPool("") + if err != nil { + t.Fatalf("Failed to connect to docker: %s", err) + } + + resource, err := pool.Run("microsoft/mssql-server-linux", "latest", []string{"ACCEPT_EULA=Y", "SA_PASSWORD=yourStrong(!)Password"}) + if err != nil { + t.Fatalf("Could not start local MSSQL docker container: %s", err) + } + + cleanup = func() { + err := pool.Purge(resource) + if err != nil { + t.Fatalf("Failed to cleanup local DynamoDB: %s", err) + } + } + + retURL = fmt.Sprintf("sqlserver://sa:yourStrong(!)Password@localhost:%s", resource.GetPort("1433/tcp")) + + // exponential backoff-retry, because the mssql container may not be able to accept connections yet + if err = pool.Retry(func() error { + var err error + var db *sql.DB + db, err = sql.Open("mssql", retURL) + if err != nil { + return err + } + return db.Ping() + }); err != nil { + t.Fatalf("Could not connect to MSSQL docker container: %s", err) + } + + return +} + +func TestMSSQL_Initialize(t *testing.T) { + cleanup, connURL := prepareMSSQLTestContainer(t) + defer cleanup() + + conf := &DatabaseConfig{ + DatabaseType: msSQLTypeName, + ConnectionDetails: map[string]interface{}{ + "connection_url": connURL, + }, + } + + dbRaw, err := BuiltinFactory(conf, nil, &log.NullLogger{}) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Deconsturct the middleware chain to get the underlying mssql object + dbTracer := dbRaw.(*databaseTracingMiddleware) + dbMetrics := dbTracer.next.(*databaseMetricsMiddleware) + db := dbMetrics.next.(*MSSQL) + connProducer := db.ConnectionProducer.(*sqlConnectionProducer) + + err = dbRaw.Initialize(conf.ConnectionDetails) + if err != nil { + t.Fatalf("err: %s", err) + } + + if !connProducer.initalized { + t.Fatal("Database should be initalized") + } + + err = dbRaw.Close() + if err != nil { + t.Fatalf("err: %s", err) + } + + if connProducer.db != nil { + t.Fatal("db object should be nil") + } +} + +func TestMSSQL_CreateUser(t *testing.T) { + cleanup, connURL := prepareMSSQLTestContainer(t) + defer cleanup() + + conf := &DatabaseConfig{ + DatabaseType: msSQLTypeName, + ConnectionDetails: map[string]interface{}{ + "connection_url": connURL, + }, + } + + db, err := BuiltinFactory(conf, nil, &log.NullLogger{}) + if err != nil { + t.Fatalf("err: %s", err) + } + err = db.Initialize(conf.ConnectionDetails) + if err != nil { + t.Fatalf("err: %s", err) + } + + username, err := db.GenerateUsername("test") + if err != nil { + t.Fatalf("err: %s", err) + } + + password, err := db.GeneratePassword() + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err := db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Test with no configured Creation Statememt + err = db.CreateUser(Statements{}, username, password, expiration) + if err == nil { + t.Fatal("Expected error when no creation statement is provided") + } + + statements := Statements{ + CreationStatements: testMSSQLRole, + } + + err = db.CreateUser(statements, username, password, expiration) + if err != nil { + t.Fatalf("err: %s", err) + } + + username, err = db.GenerateUsername("test") + if err != nil { + t.Fatalf("err: %s", err) + } + + password, err = db.GeneratePassword() + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err = db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } +} + +func TestMSSQL_RevokeUser(t *testing.T) { + cleanup, connURL := prepareMSSQLTestContainer(t) + defer cleanup() + + conf := &DatabaseConfig{ + DatabaseType: msSQLTypeName, + ConnectionDetails: map[string]interface{}{ + "connection_url": connURL, + }, + } + + db, err := BuiltinFactory(conf, nil, &log.NullLogger{}) + if err != nil { + t.Fatalf("err: %s", err) + } + err = db.Initialize(conf.ConnectionDetails) + if err != nil { + t.Fatalf("err: %s", err) + } + + username, err := db.GenerateUsername("test") + if err != nil { + t.Fatalf("err: %s", err) + } + + password, err := db.GeneratePassword() + if err != nil { + t.Fatalf("err: %s", err) + } + + expiration, err := db.GenerateExpiration(time.Minute) + if err != nil { + t.Fatalf("err: %s", err) + } + + statements := Statements{ + CreationStatements: testMSSQLRole, + } + + err = db.CreateUser(statements, username, password, expiration) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Test default revoke statememts + err = db.RevokeUser(statements, username) + if err != nil { + t.Fatalf("err: %s", err) + } +} + +const testMSSQLRole = ` +CREATE LOGIN [{{name}}] WITH PASSWORD = '{{password}}'; +CREATE USER [{{name}}] FOR LOGIN [{{name}}]; +GRANT SELECT, INSERT, UPDATE, DELETE ON SCHEMA::dbo TO [{{name}}];` From ac519abecf06348e6e35828533985874a7007bb4 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 3 Apr 2017 17:52:29 -0700 Subject: [PATCH 049/152] Plugin catalog --- builtin/logical/database/dbs/db.go | 17 ++- builtin/logical/database/dbs/plugin.go | 50 +-------- .../database/path_config_connection.go | 11 +- command/server.go | 13 +++ command/server/config.go | 4 +- helper/pluginutil/runner.go | 61 +++++++++++ logical/system_view.go | 7 ++ vault/core.go | 20 +++- vault/dynamic_system_view.go | 5 + vault/logical_system.go | 89 +++++++++++++++ vault/plugin_catalog.go | 101 ++++++++++++++++++ 11 files changed, 310 insertions(+), 68 deletions(-) create mode 100644 helper/pluginutil/runner.go create mode 100644 vault/plugin_catalog.go diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go index 2637a73d1055..8d44a474e427 100644 --- a/builtin/logical/database/dbs/db.go +++ b/builtin/logical/database/dbs/db.go @@ -20,8 +20,7 @@ const ( var ( ErrUnsupportedDatabaseType = errors.New("unsupported database type") ErrEmptyCreationStatement = errors.New("empty creation statements") - ErrEmptyPluginCommand = errors.New("empty plugin command") - ErrEmptyPluginChecksum = errors.New("empty plugin checksum") + ErrEmptyPluginName = errors.New("empty plugin name") ) // Factory function definition @@ -95,18 +94,19 @@ func BuiltinFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Log // PluginFactory is used to build plugin database types. It wraps the database // object in a logging and metrics middleware. func PluginFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Logger) (DatabaseType, error) { - if conf.PluginCommand == "" { - return nil, ErrEmptyPluginCommand + if conf.PluginName == "" { + return nil, ErrEmptyPluginName } - if conf.PluginChecksum == "" { - return nil, ErrEmptyPluginChecksum + pluginMeta, err := sys.LookupPlugin(conf.PluginName) + if err != nil { + return nil, err } // Make sure the database type is set to plugin conf.DatabaseType = pluginTypeName - db, err := newPluginClient(sys, conf.PluginCommand, conf.PluginChecksum) + db, err := newPluginClient(sys, pluginMeta) if err != nil { return nil, err } @@ -149,8 +149,7 @@ type DatabaseConfig struct { MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"` MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"` MaxConnectionLifetime time.Duration `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"` - PluginCommand string `json:"plugin_command" structs:"plugin_command" mapstructure:"plugin_command"` - PluginChecksum string `json:"plugin_checksum" structs:"plugin_checksum" mapstructure:"plugin_checksum"` + PluginName string `json:"plugin_name" structs:"plugin_name" mapstructure:"plugin_name"` } // GetFactory returns the appropriate factory method for the given database diff --git a/builtin/logical/database/dbs/plugin.go b/builtin/logical/database/dbs/plugin.go index 4bac0d16e190..791f3b465195 100644 --- a/builtin/logical/database/dbs/plugin.go +++ b/builtin/logical/database/dbs/plugin.go @@ -1,12 +1,8 @@ package dbs import ( - "crypto/sha256" - "encoding/hex" "fmt" "net/rpc" - "os/exec" - "strings" "sync" "time" @@ -55,59 +51,17 @@ func (dc *DatabasePluginClient) Close() error { // newPluginClient returns a databaseRPCClient with a connection to a running // plugin. The client is wrapped in a DatabasePluginClient object to ensure the // plugin is killed on call of Close(). -func newPluginClient(sys pluginutil.Wrapper, command, checksum string) (DatabaseType, error) { +func newPluginClient(sys pluginutil.Wrapper, pluginRunner *pluginutil.PluginRunner) (DatabaseType, error) { // pluginMap is the map of plugins we can dispense. var pluginMap = map[string]plugin.Plugin{ "database": new(DatabasePlugin), } - // Get a CA TLS Certificate - CACertBytes, CACert, CAKey, err := pluginutil.GenerateCACert() + client, err := pluginRunner.Run(sys, pluginMap, handshakeConfig, []string{}) if err != nil { return nil, err } - // Use CA to sign a client cert and return a configured TLS config - clientTLSConfig, err := pluginutil.CreateClientTLSConfig(CACert, CAKey) - if err != nil { - return nil, err - } - - // Use CA to sign a server cert and wrap the values in a response wrapped - // token. - wrapToken, err := pluginutil.WrapServerConfig(sys, CACertBytes, CACert, CAKey) - if err != nil { - return nil, err - } - - // Add the response wrap token to the ENV of the plugin - commandArr := strings.Split(command, " ") - var cmd *exec.Cmd - if len(commandArr) > 1 { - cmd = exec.Command(commandArr[0], commandArr[1:]...) - } else { - cmd = exec.Command(commandArr[0]) - } - cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", pluginutil.PluginUnwrapTokenEnv, wrapToken)) - - checksumDecoded, err := hex.DecodeString(checksum) - if err != nil { - return nil, err - } - - secureConfig := &plugin.SecureConfig{ - Checksum: checksumDecoded, - Hash: sha256.New(), - } - - client := plugin.NewClient(&plugin.ClientConfig{ - HandshakeConfig: handshakeConfig, - Plugins: pluginMap, - Cmd: cmd, - TLSConfig: clientTLSConfig, - SecureConfig: secureConfig, - }) - // Connect via RPC rpcClient, err := client.Client() if err != nil { diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index b4c699750d55..a0494d71edee 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -112,13 +112,7 @@ reduced to the same size.`, a zero or negative value reuses connections forever.`, }, - "plugin_command": &framework.FieldSchema{ - Type: framework.TypeString, - Description: `Maximum amount of time a connection may be reused; - a zero or negative value reuses connections forever.`, - }, - - "plugin_checksum": &framework.FieldSchema{ + "plugin_name": &framework.FieldSchema{ Type: framework.TypeString, Description: `Maximum amount of time a connection may be reused; a zero or negative value reuses connections forever.`, @@ -223,8 +217,7 @@ func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework. MaxOpenConnections: maxOpenConns, MaxIdleConnections: maxIdleConns, MaxConnectionLifetime: maxConnLifetime, - PluginCommand: data.Get("plugin_command").(string), - PluginChecksum: data.Get("plugin_checksum").(string), + PluginName: data.Get("plugin_name").(string), } name := data.Get("name").(string) diff --git a/command/server.go b/command/server.go index 09658b94948a..d6eb0d76d95e 100644 --- a/command/server.go +++ b/command/server.go @@ -8,6 +8,7 @@ import ( "net/url" "os" "os/signal" + "path/filepath" "runtime" "sort" "strconv" @@ -20,6 +21,7 @@ import ( colorable "github.com/mattn/go-colorable" log "github.com/mgutz/logxi/v1" + homedir "github.com/mitchellh/go-homedir" "google.golang.org/grpc/grpclog" @@ -237,11 +239,22 @@ func (c *ServerCommand) Run(args []string) int { DefaultLeaseTTL: config.DefaultLeaseTTL, ClusterName: config.ClusterName, CacheSize: config.CacheSize, + PluginDirectory: config.PluginDirectory, } if dev { coreConfig.DevToken = devRootTokenID } + if config.PluginDirectory == "" { + homePath, err := homedir.Dir() + if err != nil { + c.Ui.Output(fmt.Sprintf( + "Error getting user's home directory: %v", err)) + return 1 + } + coreConfig.PluginDirectory = filepath.Join(homePath, "/vault-plugins/") + } + var disableClustering bool // Initialize the separate HA physical backend, if it exists diff --git a/command/server/config.go b/command/server/config.go index 00edd5de9342..a57fdad13bf9 100644 --- a/command/server/config.go +++ b/command/server/config.go @@ -38,7 +38,8 @@ type Config struct { DefaultLeaseTTL time.Duration `hcl:"-"` DefaultLeaseTTLRaw string `hcl:"default_lease_ttl"` - ClusterName string `hcl:"cluster_name"` + ClusterName string `hcl:"cluster_name"` + PluginDirectory string `hcl:"plugin_directory"` } // DevConfig is a Config that is used for dev mode of Vault. @@ -339,6 +340,7 @@ func ParseConfig(d string, logger log.Logger) (*Config, error) { "default_lease_ttl", "max_lease_ttl", "cluster_name", + "plugin_directory", } if err := checkHCLKeys(list, valid); err != nil { return nil, err diff --git a/helper/pluginutil/runner.go b/helper/pluginutil/runner.go new file mode 100644 index 000000000000..143a4c839145 --- /dev/null +++ b/helper/pluginutil/runner.go @@ -0,0 +1,61 @@ +package pluginutil + +import ( + "crypto/sha256" + "fmt" + "os/exec" + + plugin "github.com/hashicorp/go-plugin" +) + +type Looker interface { + LookupPlugin(string) (*PluginRunner, error) +} + +type PluginRunner struct { + Name string `json:"name"` + Command string `json:"command"` + Args []string `json:"args"` + Sha256 []byte `json:"sha256"` +} + +func (r *PluginRunner) Run(wrapper Wrapper, pluginMap map[string]plugin.Plugin, hs plugin.HandshakeConfig, env []string) (*plugin.Client, error) { + // Get a CA TLS Certificate + CACertBytes, CACert, CAKey, err := GenerateCACert() + if err != nil { + return nil, err + } + + // Use CA to sign a client cert and return a configured TLS config + clientTLSConfig, err := CreateClientTLSConfig(CACert, CAKey) + if err != nil { + return nil, err + } + + // Use CA to sign a server cert and wrap the values in a response wrapped + // token. + wrapToken, err := WrapServerConfig(wrapper, CACertBytes, CACert, CAKey) + if err != nil { + return nil, err + } + + // Add the response wrap token to the ENV of the plugin + cmd := exec.Command(r.Command, r.Args...) + cmd.Env = append(cmd.Env, env...) + cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginUnwrapTokenEnv, wrapToken)) + + secureConfig := &plugin.SecureConfig{ + Checksum: r.Sha256, + Hash: sha256.New(), + } + + client := plugin.NewClient(&plugin.ClientConfig{ + HandshakeConfig: hs, + Plugins: pluginMap, + Cmd: cmd, + TLSConfig: clientTLSConfig, + SecureConfig: secureConfig, + }) + + return client, nil +} diff --git a/logical/system_view.go b/logical/system_view.go index 56254b33a17a..a9626bc50ee6 100644 --- a/logical/system_view.go +++ b/logical/system_view.go @@ -5,6 +5,7 @@ import ( "time" "github.com/hashicorp/vault/helper/consts" + "github.com/hashicorp/vault/helper/pluginutil" ) // SystemView exposes system configuration information in a safe way @@ -42,6 +43,8 @@ type SystemView interface { // ResponseWrapData wraps the given data in a cubbyhole and returns the // token used to unwrap. ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (string, error) + + LookupPlugin(string) (*pluginutil.PluginRunner, error) } type StaticSystemView struct { @@ -81,3 +84,7 @@ func (d StaticSystemView) ReplicationState() consts.ReplicationState { func (d StaticSystemView) ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (string, error) { return "", errors.New("ResponseWrapData is not implimented in StaticSystemView") } + +func (d StaticSystemView) LookupPlugin(name string) (*pluginutil.PluginRunner, error) { + return nil, errors.New("LookupPlugin is not implimented in StaticSystemView") +} diff --git a/vault/core.go b/vault/core.go index ea378fa8ad94..08a828643ad9 100644 --- a/vault/core.go +++ b/vault/core.go @@ -10,6 +10,7 @@ import ( "net" "net/http" "net/url" + "path/filepath" "sync" "time" @@ -330,6 +331,12 @@ type Core struct { // uiEnabled indicates whether Vault Web UI is enabled or not uiEnabled bool + + // pluginDirectory is the location vault will look for plugins + pluginDirectory string + + // pluginCatalog is used to manage plugin configurations + pluginCatalog *PluginCatalog } // CoreConfig is used to parameterize a core @@ -374,6 +381,8 @@ type CoreConfig struct { EnableUI bool `json:"ui" structs:"ui" mapstructure:"ui"` + PluginDirectory string `json:"plugin_directory" structs:"plugin_directory" mapstructure:"plugin_directory"` + ReloadFuncs *map[string][]ReloadFunc ReloadFuncsLock *sync.RWMutex } @@ -453,8 +462,13 @@ func NewCore(conf *CoreConfig) (*Core, error) { } } - // Construct a new AES-GCM barrier var err error + c.pluginDirectory, err = filepath.Abs(conf.PluginDirectory) + if err != nil { + return nil, fmt.Errorf("core setup failed: %v", err) + } + + // Construct a new AES-GCM barrier c.barrier, err = NewAESGCMBarrier(c.physical) if err != nil { return nil, fmt.Errorf("barrier setup failed: %v", err) @@ -1280,6 +1294,10 @@ func (c *Core) postUnseal() (retErr error) { if err := c.setupAuditedHeadersConfig(); err != nil { return err } + if err := c.setupPluginCatalog(); err != nil { + return err + } + if c.ha != nil { if err := c.startClusterListener(); err != nil { return err diff --git a/vault/dynamic_system_view.go b/vault/dynamic_system_view.go index 4c6807ace930..f318f3ab13cd 100644 --- a/vault/dynamic_system_view.go +++ b/vault/dynamic_system_view.go @@ -4,6 +4,7 @@ import ( "time" "github.com/hashicorp/vault/helper/consts" + "github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/logical" ) @@ -114,3 +115,7 @@ func (d dynamicSystemView) ResponseWrapData(data map[string]interface{}, ttl tim return resp.WrapInfo.Token, nil } + +func (d dynamicSystemView) LookupPlugin(name string) (*pluginutil.PluginRunner, error) { + return d.core.pluginCatalog.Get(name) +} diff --git a/vault/logical_system.go b/vault/logical_system.go index 1c439506ca79..f5dbe2affa69 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -63,6 +63,8 @@ func NewSystemBackend(core *Core, config *logical.BackendConfig) (logical.Backen "replication/reindex", "rotate", "config/auditing/*", + "plugin-catalog", + "plugin-catalog/*", }, Unauthenticated: []string{ @@ -692,6 +694,30 @@ func NewSystemBackend(core *Core, config *logical.BackendConfig) (logical.Backen HelpSynopsis: strings.TrimSpace(sysHelp["audited-headers"][0]), HelpDescription: strings.TrimSpace(sysHelp["audited-headers"][1]), }, + &framework.Path{ + Pattern: "plugin-catalog/(?P.+)", + + Fields: map[string]*framework.FieldSchema{ + "name": &framework.FieldSchema{ + Type: framework.TypeString, + }, + "sha_256": &framework.FieldSchema{ + Type: framework.TypeString, + }, + "command": &framework.FieldSchema{ + Type: framework.TypeString, + }, + }, + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.UpdateOperation: b.handlePluginCatalogUpdate, + logical.DeleteOperation: b.handlePluginCatalogDelete, + logical.ReadOperation: b.handlePluginCatalogRead, + }, + + HelpSynopsis: strings.TrimSpace(sysHelp["audited-headers-name"][0]), + HelpDescription: strings.TrimSpace(sysHelp["audited-headers-name"][1]), + }, }, } @@ -724,6 +750,69 @@ func (b *SystemBackend) invalidate(key string) { } } +func (b *SystemBackend) handlePluginCatalogUpdate(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + pluginName := d.Get("name").(string) + if pluginName == "" { + return logical.ErrorResponse("missing plugin name"), nil + } + + sha256 := d.Get("sha_256").(string) + if sha256 == "" { + return logical.ErrorResponse("missing SHA-256 value"), nil + } + + command := d.Get("command").(string) + if command == "" { + return logical.ErrorResponse("missing command value"), nil + } + + sha256Bytes, err := hex.DecodeString(sha256) + if err != nil { + return logical.ErrorResponse("Could not decode SHA-256 value from Hex"), err + } + + err = b.Core.pluginCatalog.Set(pluginName, command, sha256Bytes) + if err != nil { + return nil, err + } + + return nil, nil +} + +func (b *SystemBackend) handlePluginCatalogRead(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + pluginName := d.Get("name").(string) + if pluginName == "" { + return logical.ErrorResponse("missing plugin name"), nil + } + plugin, err := b.Core.pluginCatalog.Get(pluginName) + if err != nil { + return nil, err + } + + return &logical.Response{ + Data: map[string]interface{}{ + "plugin": plugin, + }, + }, nil +} + +func (b *SystemBackend) handlePluginCatalogDelete(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + pluginName := d.Get("name").(string) + if pluginName == "" { + return logical.ErrorResponse("missing plugin name"), nil + } + plugin, err := b.Core.pluginCatalog.Get(pluginName) + if err != nil { + return nil, err + } + + return &logical.Response{ + Data: map[string]interface{}{ + "plugin": plugin, + }, + }, nil +} + // handleAuditedHeaderUpdate creates or overwrites a header entry func (b *SystemBackend) handleAuditedHeaderUpdate(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { header := d.Get("header").(string) diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go new file mode 100644 index 000000000000..c1f504d2c2be --- /dev/null +++ b/vault/plugin_catalog.go @@ -0,0 +1,101 @@ +package vault + +import ( + "encoding/json" + "errors" + "fmt" + "path/filepath" + "strings" + "sync" + + "github.com/hashicorp/vault/helper/jsonutil" + "github.com/hashicorp/vault/helper/pluginutil" + "github.com/hashicorp/vault/logical" +) + +var ( + pluginCatalogPrefix = "plugin-catalog/" +) + +type PluginCatalog struct { + catalogView *BarrierView + directory string + + lock sync.RWMutex + builtin map[string]*pluginutil.PluginRunner +} + +func NewPluginCatalog(view *BarrierView, directory string) *PluginCatalog { + return &PluginCatalog{ + catalogView: view.SubView(pluginCatalogPrefix), + directory: directory, + } +} + +func (c *Core) setupPluginCatalog() error { + catalog := NewPluginCatalog(c.systemBarrierView, c.pluginDirectory) + c.pluginCatalog = catalog + + return nil +} + +func (c *PluginCatalog) Get(name string) (*pluginutil.PluginRunner, error) { + out, err := c.catalogView.Get(name) + if err != nil { + return nil, fmt.Errorf("failed to retrieve plugin \"%s\": %v", name, err) + } + if out == nil { + return nil, fmt.Errorf("no plugin found with name: %s", name) + } + + entry := new(pluginutil.PluginRunner) + if err := jsonutil.DecodeJSON(out.Value, entry); err != nil { + return nil, fmt.Errorf("failed to decode plugin entry: %v", err) + } + + return entry, nil +} + +func (c *PluginCatalog) Set(name, command string, sha256 []byte) error { + parts := strings.Split(command, " ") + command = parts[0] + args := parts[1:] + + command = filepath.Join(c.directory, command) + + // Best effort check to make sure the command isn't breaking out of the + // configured plugin directory. + sym, err := filepath.EvalSymlinks(command) + if err != nil { + return fmt.Errorf("error while validating the command path: %v", err) + } + symAbs, err := filepath.Abs(filepath.Dir(sym)) + if err != nil { + return fmt.Errorf("error while validating the command path: %v", err) + } + + if symAbs != c.directory { + return errors.New("can not execute files outside of configured plugin directory") + } + + entry := &pluginutil.PluginRunner{ + Name: name, + Command: command, + Args: args, + Sha256: sha256, + } + + buf, err := json.Marshal(entry) + if err != nil { + return fmt.Errorf("failed to encode plugin entry: %v", err) + } + + logicalEntry := logical.StorageEntry{ + Key: name, + Value: buf, + } + if err := c.catalogView.Put(&logicalEntry); err != nil { + return fmt.Errorf("failed to persist plugin entry: %v", err) + } + return nil +} From 1faa5fc020c2a46b11869bd766b0a6ffe4321aca Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 3 Apr 2017 18:30:38 -0700 Subject: [PATCH 050/152] On change of configuration rotate the database type --- .../database/path_config_connection.go | 20 +++++++------------ 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index a0494d71edee..a1d32d572f65 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -243,29 +243,23 @@ func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework. } if verifyConnection { - return logical.ErrorResponse(err.Error()), nil + return logical.ErrorResponse("Could not verify connection"), nil } } if _, ok := b.connections[name]; ok { - newType := db.Type() - - // Don't update connection until the reset api is hit, close for - // now. - err = db.Close() + // Close and remove the old connection + err := b.connections[name].Close() if err != nil { return nil, err } - // Don't allow the connection type to change - if b.connections[name].Type() != newType { - return logical.ErrorResponse("Can not change type of existing connection."), nil - } - } else { - // Save the new connection - b.connections[name] = db + delete(b.connections, name) } + // Save the new connection + b.connections[name] = db + // Store it entry, err := logical.StorageEntryJSON(fmt.Sprintf("dbs/%s", name), config) if err != nil { From 8e3cb50bfc58e86a10dbeec512d49ffc4051b957 Mon Sep 17 00:00:00 2001 From: Calvin Leung Huang Date: Tue, 4 Apr 2017 14:32:42 -0400 Subject: [PATCH 051/152] Database refactor invalidate (#2566) * WIP on invalidate function * cassandraConnectionProducer has Close() * Delete database from connections map on successful db.Close() * Move clear connection into its own func * Use const for database config path --- builtin/logical/database/backend.go | 31 +++++++++++++++++-- .../database/path_config_connection.go | 8 ++--- 2 files changed, 31 insertions(+), 8 deletions(-) diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index f8bcc60f1d0a..4d069a432899 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -12,6 +12,8 @@ import ( "github.com/hashicorp/vault/logical/framework" ) +const databaseConfigPath = "database/dbs/" + func Factory(conf *logical.BackendConfig) (logical.Backend, error) { return Backend(conf).Setup(conf) } @@ -41,6 +43,8 @@ func Backend(conf *logical.BackendConfig) *databaseBackend { }, Clean: b.closeAllDBs, + + Invalidate: b.invalidate, } b.logger = conf.Logger @@ -123,9 +127,32 @@ func (b *databaseBackend) Role(s logical.Storage, n string) (*roleEntry, error) return &result, nil } +func (b *databaseBackend) invalidate(key string) { + b.Lock() + defer b.Unlock() + + switch { + case strings.HasPrefix(key, databaseConfigPath): + name := strings.TrimPrefix(key, databaseConfigPath) + b.clearConnection(name) + } +} + +// clearConnection closes the database connection and +// removes it from the b.connections map. +func (b *databaseBackend) clearConnection(name string) { + db, ok := b.connections[name] + if ok { + db.Close() + delete(b.connections, name) + } +} + const backendHelp = ` -The PostgreSQL backend dynamically generates database users. +The database backend supports using many different databases +as secret backends, including but not limited to: +cassandra, msslq, mysql, postgres After mounting this backend, configure it using the endpoints within -the "config/" path. +the "database/dbs/" path. ` diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index a1d32d572f65..be2038c31cab 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -40,13 +40,9 @@ func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framew b.Lock() defer b.Unlock() - db, ok := b.connections[name] - if ok { - db.Close() - delete(b.connections, name) - } + b.clearConnection(name) - db, err := b.getOrCreateDBObj(req.Storage, name) + _, err := b.getOrCreateDBObj(req.Storage, name) if err != nil { return nil, err } From df944f2d92ab2f4a7e530b623ae7b561612b81a6 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 4 Apr 2017 11:33:58 -0700 Subject: [PATCH 052/152] Don't return strings, always structs --- builtin/logical/database/dbs/plugin.go | 44 ++++++++++++++++---------- 1 file changed, 28 insertions(+), 16 deletions(-) diff --git a/builtin/logical/database/dbs/plugin.go b/builtin/logical/database/dbs/plugin.go index 791f3b465195..441f97ca0fd8 100644 --- a/builtin/logical/database/dbs/plugin.go +++ b/builtin/logical/database/dbs/plugin.go @@ -169,24 +169,24 @@ func (dr *databasePluginRPCClient) Close() error { } func (dr *databasePluginRPCClient) GenerateUsername(displayName string) (string, error) { - var username string - err := dr.client.Call("Plugin.GenerateUsername", displayName, &username) + resp := &GenerateUsernameResponse{} + err := dr.client.Call("Plugin.GenerateUsername", displayName, resp) - return username, err + return resp.Username, err } func (dr *databasePluginRPCClient) GeneratePassword() (string, error) { - var password string - err := dr.client.Call("Plugin.GeneratePassword", struct{}{}, &password) + resp := &GeneratePasswordResponse{} + err := dr.client.Call("Plugin.GeneratePassword", struct{}{}, resp) - return password, err + return resp.Password, err } func (dr *databasePluginRPCClient) GenerateExpiration(duration time.Duration) (string, error) { - var expiration string - err := dr.client.Call("Plugin.GenerateExpiration", duration, &expiration) + resp := &GenerateExpirationResponse{} + err := dr.client.Call("Plugin.GenerateExpiration", duration, resp) - return expiration, err + return resp.Expiration, err } // ---- RPC server domain ---- @@ -230,28 +230,28 @@ func (ds *databasePluginRPCServer) Close(_ struct{}, _ *struct{}) error { return nil } -func (ds *databasePluginRPCServer) GenerateUsername(args string, resp *string) error { +func (ds *databasePluginRPCServer) GenerateUsername(args string, resp *GenerateUsernameResponse) error { var err error - *resp, err = ds.impl.GenerateUsername(args) + resp.Username, err = ds.impl.GenerateUsername(args) return err } -func (ds *databasePluginRPCServer) GeneratePassword(_ struct{}, resp *string) error { +func (ds *databasePluginRPCServer) GeneratePassword(_ struct{}, resp *GeneratePasswordResponse) error { var err error - *resp, err = ds.impl.GeneratePassword() + resp.Password, err = ds.impl.GeneratePassword() return err } -func (ds *databasePluginRPCServer) GenerateExpiration(args time.Duration, resp *string) error { +func (ds *databasePluginRPCServer) GenerateExpiration(args time.Duration, resp *GenerateExpirationResponse) error { var err error - *resp, err = ds.impl.GenerateExpiration(args) + resp.Expiration, err = ds.impl.GenerateExpiration(args) return err } -// ---- Request Args domain ---- +// ---- Request Args Domain ---- type CreateUserRequest struct { Statements Statements @@ -270,3 +270,15 @@ type RevokeUserRequest struct { Statements Statements Username string } + +// ---- Response Args Domain ---- + +type GenerateUsernameResponse struct { + Username string +} +type GenerateExpirationResponse struct { + Expiration string +} +type GeneratePasswordResponse struct { + Password string +} From 73a2cdf6a58933ae6d8b88ee6cfe3ae8e98e0bc9 Mon Sep 17 00:00:00 2001 From: Calvin Leung Huang Date: Tue, 4 Apr 2017 17:26:59 -0400 Subject: [PATCH 053/152] Do not mark conn as initialized until the end (#2567) --- builtin/logical/database/dbs/connectionproducer.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/builtin/logical/database/dbs/connectionproducer.go b/builtin/logical/database/dbs/connectionproducer.go index b5dc93951777..31ef2853b702 100644 --- a/builtin/logical/database/dbs/connectionproducer.go +++ b/builtin/logical/database/dbs/connectionproducer.go @@ -54,12 +54,13 @@ func (c *sqlConnectionProducer) Initialize(conf map[string]interface{}) error { if err != nil { return err } - c.initalized = true if _, err := c.connection(); err != nil { - return fmt.Errorf("Error Initalizing Connection: %s", err) + return fmt.Errorf("error initalizing connection: %s", err) } + c.initalized = true + return nil } From f6b45bdcfb6b1e309e1fe423ff11a152220b7122 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 4 Apr 2017 14:43:39 -0700 Subject: [PATCH 054/152] Execute builtin plugins --- command/server.go | 63 ++++++++++++++++++++++++++++++----------- vault/core.go | 12 +++++++- vault/plugin_catalog.go | 46 +++++++++++++++++++----------- 3 files changed, 87 insertions(+), 34 deletions(-) diff --git a/command/server.go b/command/server.go index d6eb0d76d95e..3b1c771bbeec 100644 --- a/command/server.go +++ b/command/server.go @@ -1,8 +1,10 @@ package command import ( + "crypto/sha256" "encoding/base64" "fmt" + "io" "net" "net/http" "net/url" @@ -131,6 +133,33 @@ func (c *ServerCommand) Run(args []string) int { dev = true } + // Record the vault binary's location and SHA-256 checksum for use in + // builtin plugins. + ex, err := os.Executable() + if err != nil { + c.Ui.Output(fmt.Sprintf( + "Error looking up vault binary: %s", err)) + return 1 + } + + file, err := os.Open(ex) + if err != nil { + c.Ui.Output(fmt.Sprintf( + "Error loading vault binary: %s", err)) + return 1 + } + defer file.Close() + + hash := sha256.New() + _, err = io.Copy(hash, file) + if err != nil { + c.Ui.Output(fmt.Sprintf( + "Error checksumming vault binary: %s", err)) + return 1 + } + + sha256Value := hash.Sum(nil) + // Validation if !dev { switch { @@ -225,21 +254,23 @@ func (c *ServerCommand) Run(args []string) int { } coreConfig := &vault.CoreConfig{ - Physical: backend, - RedirectAddr: config.Backend.RedirectAddr, - HAPhysical: nil, - Seal: seal, - AuditBackends: c.AuditBackends, - CredentialBackends: c.CredentialBackends, - LogicalBackends: c.LogicalBackends, - Logger: c.logger, - DisableCache: config.DisableCache, - DisableMlock: config.DisableMlock, - MaxLeaseTTL: config.MaxLeaseTTL, - DefaultLeaseTTL: config.DefaultLeaseTTL, - ClusterName: config.ClusterName, - CacheSize: config.CacheSize, - PluginDirectory: config.PluginDirectory, + Physical: backend, + RedirectAddr: config.Backend.RedirectAddr, + HAPhysical: nil, + Seal: seal, + AuditBackends: c.AuditBackends, + CredentialBackends: c.CredentialBackends, + LogicalBackends: c.LogicalBackends, + Logger: c.logger, + DisableCache: config.DisableCache, + DisableMlock: config.DisableMlock, + MaxLeaseTTL: config.MaxLeaseTTL, + DefaultLeaseTTL: config.DefaultLeaseTTL, + ClusterName: config.ClusterName, + CacheSize: config.CacheSize, + PluginDirectory: config.PluginDirectory, + VaultBinaryLocation: ex, + VaultBinarySHA256: sha256Value, } if dev { coreConfig.DevToken = devRootTokenID @@ -252,7 +283,7 @@ func (c *ServerCommand) Run(args []string) int { "Error getting user's home directory: %v", err)) return 1 } - coreConfig.PluginDirectory = filepath.Join(homePath, "/vault-plugins/") + coreConfig.PluginDirectory = filepath.Join(homePath, "/.vault-plugins/") } var disableClustering bool diff --git a/vault/core.go b/vault/core.go index 08a828643ad9..ffd36683be59 100644 --- a/vault/core.go +++ b/vault/core.go @@ -335,6 +335,12 @@ type Core struct { // pluginDirectory is the location vault will look for plugins pluginDirectory string + // vaultBinaryLocation is used to run builtin plugins in secure mode + vaultBinaryLocation string + + // vaultBinarySHA256 is used to run builtin plugins in secure mode + vaultBinarySHA256 []byte + // pluginCatalog is used to manage plugin configurations pluginCatalog *PluginCatalog } @@ -381,7 +387,9 @@ type CoreConfig struct { EnableUI bool `json:"ui" structs:"ui" mapstructure:"ui"` - PluginDirectory string `json:"plugin_directory" structs:"plugin_directory" mapstructure:"plugin_directory"` + PluginDirectory string `json:"plugin_directory" structs:"plugin_directory" mapstructure:"plugin_directory"` + VaultBinaryLocation string `json:"vault_binary_location" structs:"vault_binary_location" mapstructure:"vault_binary_location"` + VaultBinarySHA256 []byte `json:"vault_binary_sha256" structs:"vault_binary_sha256" mapstructure:"vault_binary_sha256"` ReloadFuncs *map[string][]ReloadFunc ReloadFuncsLock *sync.RWMutex @@ -439,6 +447,8 @@ func NewCore(conf *CoreConfig) (*Core, error) { clusterName: conf.ClusterName, clusterListenerShutdownCh: make(chan struct{}), clusterListenerShutdownSuccessCh: make(chan struct{}), + vaultBinaryLocation: conf.VaultBinaryLocation, + vaultBinarySHA256: conf.VaultBinarySHA256, } // Wrap the physical backend in a cache layer if enabled and not already wrapped diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go index c1f504d2c2be..88265a24521c 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugin_catalog.go @@ -10,50 +10,62 @@ import ( "github.com/hashicorp/vault/helper/jsonutil" "github.com/hashicorp/vault/helper/pluginutil" + "github.com/hashicorp/vault/helper/strutil" "github.com/hashicorp/vault/logical" ) var ( pluginCatalogPrefix = "plugin-catalog/" + builtinPlugins = []string{"mysql-database-plugin", "postgres-database-plugin"} ) type PluginCatalog struct { - catalogView *BarrierView - directory string + catalogView *BarrierView + directory string + vaultCommand string + vaultSHA256 []byte lock sync.RWMutex builtin map[string]*pluginutil.PluginRunner } -func NewPluginCatalog(view *BarrierView, directory string) *PluginCatalog { - return &PluginCatalog{ - catalogView: view.SubView(pluginCatalogPrefix), - directory: directory, - } -} - func (c *Core) setupPluginCatalog() error { - catalog := NewPluginCatalog(c.systemBarrierView, c.pluginDirectory) - c.pluginCatalog = catalog + c.pluginCatalog = &PluginCatalog{ + catalogView: c.systemBarrierView.SubView(pluginCatalogPrefix), + directory: c.pluginDirectory, + vaultCommand: c.vaultBinaryLocation, + vaultSHA256: c.vaultBinarySHA256, + } return nil } func (c *PluginCatalog) Get(name string) (*pluginutil.PluginRunner, error) { + // Look for external plugins in the barrier out, err := c.catalogView.Get(name) if err != nil { return nil, fmt.Errorf("failed to retrieve plugin \"%s\": %v", name, err) } - if out == nil { - return nil, fmt.Errorf("no plugin found with name: %s", name) + if out != nil { + entry := new(pluginutil.PluginRunner) + if err := jsonutil.DecodeJSON(out.Value, entry); err != nil { + return nil, fmt.Errorf("failed to decode plugin entry: %v", err) + } + + return entry, nil } - entry := new(pluginutil.PluginRunner) - if err := jsonutil.DecodeJSON(out.Value, entry); err != nil { - return nil, fmt.Errorf("failed to decode plugin entry: %v", err) + // Look for builtin plugins + if !strutil.StrListContains(builtinPlugins, name) { + return nil, fmt.Errorf("no plugin found with name: %s", name) } - return entry, nil + return &pluginutil.PluginRunner{ + Name: name, + Command: c.vaultCommand, + Args: []string{"plugin-exec", name}, + Sha256: c.vaultSHA256, + }, nil } func (c *PluginCatalog) Set(name, command string, sha256 []byte) error { From 485b331d6affa00fe2890e2f13e2f16fefa3f250 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 4 Apr 2017 17:12:02 -0700 Subject: [PATCH 055/152] Add a cli command to run builtin plugins --- cli/commands.go | 6 ++++ command/plugin-exec.go | 71 +++++++++++++++++++++++++++++++++++++++++ vault/plugin_catalog.go | 3 +- 3 files changed, 78 insertions(+), 2 deletions(-) create mode 100644 command/plugin-exec.go diff --git a/cli/commands.go b/cli/commands.go index 13f7c8b25aad..e7545ca906f3 100644 --- a/cli/commands.go +++ b/cli/commands.go @@ -331,5 +331,11 @@ func Commands(metaPtr *meta.Meta) map[string]cli.CommandFactory { Ui: metaPtr.Ui, }, nil }, + + "plugin-exec": func() (cli.Command, error) { + return &command.PluginExec{ + Meta: *metaPtr, + }, nil + }, } } diff --git a/command/plugin-exec.go b/command/plugin-exec.go new file mode 100644 index 000000000000..18dc3e145308 --- /dev/null +++ b/command/plugin-exec.go @@ -0,0 +1,71 @@ +package command + +import ( + "fmt" + "strings" + + "github.com/hashicorp/vault/meta" +) + +type PluginExec struct { + meta.Meta +} + +var builtinFactories = map[string]func() error{ +// "mysql-database-plugin": mysql.Factory, +// "postgres-database-plugin": postgres.Factory, +} + +func (c *PluginExec) Run(args []string) int { + flags := c.Meta.FlagSet("plugin-exec", meta.FlagSetDefault) + flags.Usage = func() { c.Ui.Error(c.Help()) } + if err := flags.Parse(args); err != nil { + return 1 + } + + args = flags.Args() + if len(args) != 1 { + flags.Usage() + c.Ui.Error(fmt.Sprintf( + "\nplugin-exec expects one argument: the plugin to execute.")) + return 1 + } + + pluginName := args[0] + + factory, ok := builtinFactories[pluginName] + if !ok { + c.Ui.Error(fmt.Sprintf( + "No plugin with the name %s found", pluginName)) + return 1 + } + + err := factory() + if err != nil { + c.Ui.Error(fmt.Sprintf( + "Error running plugin: %s", err)) + return 1 + } + + return 0 +} + +func (c *PluginExec) Synopsis() string { + return "Force the Vault node to give up active duty" +} + +func (c *PluginExec) Help() string { + helpText := ` +Usage: vault step-down [options] + + Force the Vault node to step down from active duty. + + This causes the indicated node to give up active status. Note that while the + affected node will have a short delay before attempting to grab the lock + again, if no other node grabs the lock beforehand, it is possible for the + same node to re-grab the lock and become active again. + +General Options: +` + meta.GeneralOptionsUsage() + return strings.TrimSpace(helpText) +} diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go index 88265a24521c..eccac2bd1b35 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugin_catalog.go @@ -25,8 +25,7 @@ type PluginCatalog struct { vaultCommand string vaultSHA256 []byte - lock sync.RWMutex - builtin map[string]*pluginutil.PluginRunner + lock sync.RWMutex } func (c *Core) setupPluginCatalog() error { From 8f88452fc09f7d5a33c55c309ce659a973658cc3 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 5 Apr 2017 11:00:13 -0700 Subject: [PATCH 056/152] move builtin plugins list to the pluginutil --- command/plugin-exec.go | 23 +++++++++-------------- helper/pluginutil/builtin.go | 6 ++++++ vault/plugin_catalog.go | 4 +--- 3 files changed, 16 insertions(+), 17 deletions(-) create mode 100644 helper/pluginutil/builtin.go diff --git a/command/plugin-exec.go b/command/plugin-exec.go index 18dc3e145308..f0d6a8d51a5d 100644 --- a/command/plugin-exec.go +++ b/command/plugin-exec.go @@ -4,6 +4,7 @@ import ( "fmt" "strings" + "github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/meta" ) @@ -11,11 +12,6 @@ type PluginExec struct { meta.Meta } -var builtinFactories = map[string]func() error{ -// "mysql-database-plugin": mysql.Factory, -// "postgres-database-plugin": postgres.Factory, -} - func (c *PluginExec) Run(args []string) int { flags := c.Meta.FlagSet("plugin-exec", meta.FlagSetDefault) flags.Usage = func() { c.Ui.Error(c.Help()) } @@ -33,14 +29,14 @@ func (c *PluginExec) Run(args []string) int { pluginName := args[0] - factory, ok := builtinFactories[pluginName] + runner, ok := pluginutil.BuiltinPlugins[pluginName] if !ok { c.Ui.Error(fmt.Sprintf( "No plugin with the name %s found", pluginName)) return 1 } - err := factory() + err := runner() if err != nil { c.Ui.Error(fmt.Sprintf( "Error running plugin: %s", err)) @@ -51,19 +47,18 @@ func (c *PluginExec) Run(args []string) int { } func (c *PluginExec) Synopsis() string { - return "Force the Vault node to give up active duty" + return "Runs a builtin plugin. Should only be called by vault." } func (c *PluginExec) Help() string { helpText := ` -Usage: vault step-down [options] +Usage: vault plugin-exec type - Force the Vault node to step down from active duty. + Runs a builtin plugin. Should only be called by vault. - This causes the indicated node to give up active status. Note that while the - affected node will have a short delay before attempting to grab the lock - again, if no other node grabs the lock beforehand, it is possible for the - same node to re-grab the lock and become active again. + This will execute a plugin for use in a plugable location in vault. If run by + a cli user it will print a message indicating it can not be executed by anyone + other than vault. For supported plugin types see the vault documentation. General Options: ` + meta.GeneralOptionsUsage() diff --git a/helper/pluginutil/builtin.go b/helper/pluginutil/builtin.go new file mode 100644 index 000000000000..6a464bb8243a --- /dev/null +++ b/helper/pluginutil/builtin.go @@ -0,0 +1,6 @@ +package pluginutil + +var BuiltinPlugins = map[string]func() error{ +// "mysql-database-plugin": mysql.Run, +// "postgres-database-plugin": postgres.Run, +} diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go index eccac2bd1b35..c6e4e4059bba 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugin_catalog.go @@ -10,13 +10,11 @@ import ( "github.com/hashicorp/vault/helper/jsonutil" "github.com/hashicorp/vault/helper/pluginutil" - "github.com/hashicorp/vault/helper/strutil" "github.com/hashicorp/vault/logical" ) var ( pluginCatalogPrefix = "plugin-catalog/" - builtinPlugins = []string{"mysql-database-plugin", "postgres-database-plugin"} ) type PluginCatalog struct { @@ -55,7 +53,7 @@ func (c *PluginCatalog) Get(name string) (*pluginutil.PluginRunner, error) { } // Look for builtin plugins - if !strutil.StrListContains(builtinPlugins, name) { + if _, ok := pluginutil.BuiltinPlugins[name]; !ok { return nil, fmt.Errorf("no plugin found with name: %s", name) } From 8a2e29c607664360ecff7c2f356d830c6b7746f0 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 5 Apr 2017 16:20:31 -0700 Subject: [PATCH 057/152] Refactor to use builtin plugins from an external repo --- builtin/logical/database/backend.go | 51 ++- .../database/{dbs => }/databasemiddleware.go | 2 +- builtin/logical/database/dbs/cassandra.go | 108 ----- .../database/dbs/connectionproducer.go | 280 ------------ .../database/dbs/credentialsproducer.go | 83 ---- builtin/logical/database/dbs/db.go | 196 --------- builtin/logical/database/dbs/mssql.go | 219 --------- builtin/logical/database/dbs/mssql_test.go | 221 ---------- builtin/logical/database/dbs/mysql.go | 135 ------ builtin/logical/database/dbs/mysql_test.go | 346 --------------- builtin/logical/database/dbs/postgresql.go | 279 ------------ .../logical/database/dbs/postgresql_test.go | 414 ------------------ .../database/path_config_connection.go | 78 +--- builtin/logical/database/path_roles.go | 11 +- builtin/logical/database/{dbs => }/plugin.go | 44 +- .../logical/database/{dbs => }/plugin_test.go | 2 +- command/plugin-exec.go | 4 +- helper/builtinplugins/builtin.go | 8 + helper/pluginutil/builtin.go | 6 - vault/plugin_catalog.go | 3 +- 20 files changed, 110 insertions(+), 2380 deletions(-) rename builtin/logical/database/{dbs => }/databasemiddleware.go (99%) delete mode 100644 builtin/logical/database/dbs/cassandra.go delete mode 100644 builtin/logical/database/dbs/connectionproducer.go delete mode 100644 builtin/logical/database/dbs/credentialsproducer.go delete mode 100644 builtin/logical/database/dbs/db.go delete mode 100644 builtin/logical/database/dbs/mssql.go delete mode 100644 builtin/logical/database/dbs/mssql_test.go delete mode 100644 builtin/logical/database/dbs/mysql.go delete mode 100644 builtin/logical/database/dbs/mysql_test.go delete mode 100644 builtin/logical/database/dbs/postgresql.go delete mode 100644 builtin/logical/database/dbs/postgresql_test.go rename builtin/logical/database/{dbs => }/plugin.go (88%) rename builtin/logical/database/{dbs => }/plugin_test.go (99%) create mode 100644 helper/builtinplugins/builtin.go delete mode 100644 helper/pluginutil/builtin.go diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index 4d069a432899..a2fff4ba863d 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -4,16 +4,52 @@ import ( "fmt" "strings" "sync" + "time" log "github.com/mgutz/logxi/v1" - "github.com/hashicorp/vault/builtin/logical/database/dbs" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) const databaseConfigPath = "database/dbs/" +// DatabaseType is the interface that all database objects must implement. +type DatabaseType interface { + Type() string + CreateUser(statements Statements, username, password, expiration string) error + RenewUser(statements Statements, username, expiration string) error + RevokeUser(statements Statements, username string) error + + Initialize(map[string]interface{}) error + Close() error + + GenerateUsername(displayName string) (string, error) + GeneratePassword() (string, error) + GenerateExpiration(ttl time.Duration) (string, error) +} + +// DatabaseConfig is used by the Factory function to configure a DatabaseType +// object. +type DatabaseConfig struct { + PluginName string `json:"plugin_name" structs:"plugin_name" mapstructure:"plugin_name"` + // ConnectionDetails stores the database specific connection settings needed + // by each database type. + ConnectionDetails map[string]interface{} `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"` + MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"` + MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"` + MaxConnectionLifetime time.Duration `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"` +} + +// Statements set in role creation and passed into the database type's functions. +// TODO: Add a way of setting defaults here. +type Statements struct { + CreationStatements string `json:"creation_statments" mapstructure:"creation_statements" structs:"creation_statments"` + RevocationStatements string `json:"revocation_statements" mapstructure:"revocation_statements" structs:"revocation_statements"` + RollbackStatements string `json:"rollback_statements" mapstructure:"rollback_statements" structs:"rollback_statements"` + RenewStatements string `json:"renew_statements" mapstructure:"renew_statements" structs:"renew_statements"` +} + func Factory(conf *logical.BackendConfig) (logical.Backend, error) { return Backend(conf).Setup(conf) } @@ -30,7 +66,6 @@ func Backend(conf *logical.BackendConfig) *databaseBackend { }, Paths: []*framework.Path{ - pathConfigureBuiltinConnection(&b), pathConfigurePluginConnection(&b), pathListRoles(&b), pathRoles(&b), @@ -48,12 +83,12 @@ func Backend(conf *logical.BackendConfig) *databaseBackend { } b.logger = conf.Logger - b.connections = make(map[string]dbs.DatabaseType) + b.connections = make(map[string]DatabaseType) return &b } type databaseBackend struct { - connections map[string]dbs.DatabaseType + connections map[string]DatabaseType logger log.Logger *framework.Backend @@ -73,7 +108,7 @@ func (b *databaseBackend) closeAllDBs() { // This function is used to retrieve a database object either from the cached // connection map or by using the database config in storage. The caller of this // function needs to hold the backend's lock. -func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (dbs.DatabaseType, error) { +func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (DatabaseType, error) { // if the object already is built and cached, return it db, ok := b.connections[name] if ok { @@ -88,14 +123,12 @@ func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (dbs. return nil, fmt.Errorf("failed to find entry for connection with name: %s", name) } - var config dbs.DatabaseConfig + var config DatabaseConfig if err := entry.DecodeJSON(&config); err != nil { return nil, err } - factory := config.GetFactory() - - db, err = factory(&config, b.System(), b.logger) + db, err = PluginFactory(&config, b.System(), b.logger) if err != nil { return nil, err } diff --git a/builtin/logical/database/dbs/databasemiddleware.go b/builtin/logical/database/databasemiddleware.go similarity index 99% rename from builtin/logical/database/dbs/databasemiddleware.go rename to builtin/logical/database/databasemiddleware.go index d3f037ecbca3..5892e8064a37 100644 --- a/builtin/logical/database/dbs/databasemiddleware.go +++ b/builtin/logical/database/databasemiddleware.go @@ -1,4 +1,4 @@ -package dbs +package database import ( "time" diff --git a/builtin/logical/database/dbs/cassandra.go b/builtin/logical/database/dbs/cassandra.go deleted file mode 100644 index 1be26766bdcf..000000000000 --- a/builtin/logical/database/dbs/cassandra.go +++ /dev/null @@ -1,108 +0,0 @@ -package dbs - -import ( - "fmt" - "strings" - - "github.com/gocql/gocql" - "github.com/hashicorp/vault/helper/strutil" -) - -const ( - defaultCreationCQL = `CREATE USER '{{username}}' WITH PASSWORD '{{password}}' NOSUPERUSER;` - defaultRollbackCQL = `DROP USER '{{username}}';` -) - -type Cassandra struct { - // Session is goroutine safe, however, since we reinitialize - // it when connection info changes, we want to make sure we - // can close it and use a new connection; hence the lock - ConnectionProducer - CredentialsProducer -} - -func (c *Cassandra) Type() string { - return cassandraTypeName -} - -func (c *Cassandra) getConnection() (*gocql.Session, error) { - session, err := c.connection() - if err != nil { - return nil, err - } - - return session.(*gocql.Session), nil -} - -func (c *Cassandra) CreateUser(statements Statements, username, password, expiration string) error { - // Grab the lock - c.Lock() - defer c.Unlock() - - // Get the connection - session, err := c.getConnection() - if err != nil { - return err - } - - creationCQL := statements.CreationStatements - if creationCQL == "" { - creationCQL = defaultCreationCQL - } - rollbackCQL := statements.RollbackStatements - if rollbackCQL == "" { - rollbackCQL = defaultRollbackCQL - } - - // Execute each query - for _, query := range strutil.ParseArbitraryStringSlice(creationCQL, ";") { - query = strings.TrimSpace(query) - if len(query) == 0 { - continue - } - - err = session.Query(queryHelper(query, map[string]string{ - "username": username, - "password": password, - })).Exec() - if err != nil { - for _, query := range strutil.ParseArbitraryStringSlice(rollbackCQL, ";") { - query = strings.TrimSpace(query) - if len(query) == 0 { - continue - } - - session.Query(queryHelper(query, map[string]string{ - "username": username, - "password": password, - })).Exec() - } - return err - } - } - - return nil -} - -func (c *Cassandra) RenewUser(statements Statements, username, expiration string) error { - // NOOP - return nil -} - -func (c *Cassandra) RevokeUser(statements Statements, username string) error { - // Grab the lock - c.Lock() - defer c.Unlock() - - session, err := c.getConnection() - if err != nil { - return err - } - - err = session.Query(fmt.Sprintf("DROP USER '%s'", username)).Exec() - if err != nil { - return fmt.Errorf("error removing user %s", username) - } - - return nil -} diff --git a/builtin/logical/database/dbs/connectionproducer.go b/builtin/logical/database/dbs/connectionproducer.go deleted file mode 100644 index 31ef2853b702..000000000000 --- a/builtin/logical/database/dbs/connectionproducer.go +++ /dev/null @@ -1,280 +0,0 @@ -package dbs - -import ( - "crypto/tls" - "database/sql" - "errors" - "fmt" - "strings" - "sync" - "time" - - // Import sql drivers - _ "github.com/denisenkom/go-mssqldb" - _ "github.com/go-sql-driver/mysql" - _ "github.com/lib/pq" - "github.com/mitchellh/mapstructure" - - "github.com/gocql/gocql" - "github.com/hashicorp/vault/helper/certutil" - "github.com/hashicorp/vault/helper/tlsutil" -) - -var ( - errNotInitalized = errors.New("connection has not been initalized") -) - -// ConnectionProducer can be used as an embeded interface in the DatabaseType -// definition. It implements the methods dealing with individual database -// connections and is used in all the builtin database types. -type ConnectionProducer interface { - Close() error - Initialize(map[string]interface{}) error - - sync.Locker - connection() (interface{}, error) -} - -// sqlConnectionProducer implements ConnectionProducer and provides a generic producer for most sql databases -type sqlConnectionProducer struct { - ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"` - - config *DatabaseConfig - - initalized bool - db *sql.DB - sync.Mutex -} - -func (c *sqlConnectionProducer) Initialize(conf map[string]interface{}) error { - c.Lock() - defer c.Unlock() - - err := mapstructure.Decode(conf, c) - if err != nil { - return err - } - - if _, err := c.connection(); err != nil { - return fmt.Errorf("error initalizing connection: %s", err) - } - - c.initalized = true - - return nil -} - -func (c *sqlConnectionProducer) connection() (interface{}, error) { - // If we already have a DB, test it and return - if c.db != nil { - if err := c.db.Ping(); err == nil { - return c.db, nil - } - // If the ping was unsuccessful, close it and ignore errors as we'll be - // reestablishing anyways - c.db.Close() - } - - // For mssql backend, switch to sqlserver instead - dbType := c.config.DatabaseType - if c.config.DatabaseType == "mssql" { - dbType = "sqlserver" - } - - // Otherwise, attempt to make connection - conn := c.ConnectionURL - - // Ensure timezone is set to UTC for all the conenctions - if strings.HasPrefix(conn, "postgres://") || strings.HasPrefix(conn, "postgresql://") { - if strings.Contains(conn, "?") { - conn += "&timezone=utc" - } else { - conn += "?timezone=utc" - } - } - - var err error - c.db, err = sql.Open(dbType, conn) - if err != nil { - return nil, err - } - - // Set some connection pool settings. We don't need much of this, - // since the request rate shouldn't be high. - c.db.SetMaxOpenConns(c.config.MaxOpenConnections) - c.db.SetMaxIdleConns(c.config.MaxIdleConnections) - c.db.SetConnMaxLifetime(c.config.MaxConnectionLifetime) - - return c.db, nil -} - -func (c *sqlConnectionProducer) Close() error { - // Grab the write lock - c.Lock() - defer c.Unlock() - - if c.db != nil { - c.db.Close() - } - - c.db = nil - - return nil -} - -// cassandraConnectionProducer implements ConnectionProducer and provides an -// interface for cassandra databases to make connections. -type cassandraConnectionProducer struct { - Hosts string `json:"hosts" structs:"hosts" mapstructure:"hosts"` - Username string `json:"username" structs:"username" mapstructure:"username"` - Password string `json:"password" structs:"password" mapstructure:"password"` - TLS bool `json:"tls" structs:"tls" mapstructure:"tls"` - InsecureTLS bool `json:"insecure_tls" structs:"insecure_tls" mapstructure:"insecure_tls"` - Certificate string `json:"certificate" structs:"certificate" mapstructure:"certificate"` - PrivateKey string `json:"private_key" structs:"private_key" mapstructure:"private_key"` - IssuingCA string `json:"issuing_ca" structs:"issuing_ca" mapstructure:"issuing_ca"` - ProtocolVersion int `json:"protocol_version" structs:"protocol_version" mapstructure:"protocol_version"` - ConnectTimeout int `json:"connect_timeout" structs:"connect_timeout" mapstructure:"connect_timeout"` - TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"` - Consistency string `json:"consistency" structs:"consistency" mapstructure:"consistency"` - - config *DatabaseConfig - initalized bool - session *gocql.Session - sync.Mutex -} - -func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}) error { - c.Lock() - defer c.Unlock() - - err := mapstructure.Decode(conf, c) - if err != nil { - return err - } - c.initalized = true - - if _, err := c.connection(); err != nil { - return fmt.Errorf("error Initalizing Connection: %s", err) - } - - return nil -} - -func (c *cassandraConnectionProducer) connection() (interface{}, error) { - if !c.initalized { - return nil, errNotInitalized - } - - // If we already have a DB, return it - if c.session != nil { - return c.session, nil - } - - session, err := c.createSession() - if err != nil { - return nil, err - } - - // Store the session in backend for reuse - c.session = session - - return session, nil -} - -func (c *cassandraConnectionProducer) Close() error { - // Grab the write lock - c.Lock() - defer c.Unlock() - - if c.session != nil { - c.session.Close() - } - - c.session = nil - - return nil -} - -func (c *cassandraConnectionProducer) createSession() (*gocql.Session, error) { - clusterConfig := gocql.NewCluster(strings.Split(c.Hosts, ",")...) - clusterConfig.Authenticator = gocql.PasswordAuthenticator{ - Username: c.Username, - Password: c.Password, - } - - clusterConfig.ProtoVersion = c.ProtocolVersion - if clusterConfig.ProtoVersion == 0 { - clusterConfig.ProtoVersion = 2 - } - - clusterConfig.Timeout = time.Duration(c.ConnectTimeout) * time.Second - - if c.TLS { - var tlsConfig *tls.Config - if len(c.Certificate) > 0 || len(c.IssuingCA) > 0 { - if len(c.Certificate) > 0 && len(c.PrivateKey) == 0 { - return nil, fmt.Errorf("Found certificate for TLS authentication but no private key") - } - - certBundle := &certutil.CertBundle{} - if len(c.Certificate) > 0 { - certBundle.Certificate = c.Certificate - certBundle.PrivateKey = c.PrivateKey - } - if len(c.IssuingCA) > 0 { - certBundle.IssuingCA = c.IssuingCA - } - - parsedCertBundle, err := certBundle.ToParsedCertBundle() - if err != nil { - return nil, fmt.Errorf("failed to parse certificate bundle: %s", err) - } - - tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient) - if err != nil || tlsConfig == nil { - return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%v", tlsConfig, err) - } - tlsConfig.InsecureSkipVerify = c.InsecureTLS - - if c.TLSMinVersion != "" { - var ok bool - tlsConfig.MinVersion, ok = tlsutil.TLSLookup[c.TLSMinVersion] - if !ok { - return nil, fmt.Errorf("invalid 'tls_min_version' in config") - } - } else { - // MinVersion was not being set earlier. Reset it to - // zero to gracefully handle upgrades. - tlsConfig.MinVersion = 0 - } - } - - clusterConfig.SslOpts = &gocql.SslOptions{ - Config: *tlsConfig, - } - } - - session, err := clusterConfig.CreateSession() - if err != nil { - return nil, fmt.Errorf("error creating session: %s", err) - } - - // Set consistency - if c.Consistency != "" { - consistencyValue, err := gocql.ParseConsistencyWrapper(c.Consistency) - if err != nil { - return nil, err - } - - session.SetConsistency(consistencyValue) - } - - // Verify the info - err = session.Query(`LIST USERS`).Exec() - if err != nil { - return nil, fmt.Errorf("error validating connection info: %s", err) - } - - return session, nil -} diff --git a/builtin/logical/database/dbs/credentialsproducer.go b/builtin/logical/database/dbs/credentialsproducer.go deleted file mode 100644 index 6bd543f4e175..000000000000 --- a/builtin/logical/database/dbs/credentialsproducer.go +++ /dev/null @@ -1,83 +0,0 @@ -package dbs - -import ( - "fmt" - "strings" - "time" - - uuid "github.com/hashicorp/go-uuid" -) - -// CredentialsProducer can be used as an embeded interface in the DatabaseType -// definition. It implements the methods for generating user information for a -// particular database type and is used in all the builtin database types. -type CredentialsProducer interface { - GenerateUsername(displayName string) (string, error) - GeneratePassword() (string, error) - GenerateExpiration(ttl time.Duration) (string, error) -} - -// sqlCredentialsProducer implements CredentialsProducer and provides a generic credentials producer for most sql database types. -type sqlCredentialsProducer struct { - displayNameLen int - usernameLen int -} - -func (scp *sqlCredentialsProducer) GenerateUsername(displayName string) (string, error) { - if scp.displayNameLen > 0 && len(displayName) > scp.displayNameLen { - displayName = displayName[:scp.displayNameLen] - } - userUUID, err := uuid.GenerateUUID() - if err != nil { - return "", err - } - username := fmt.Sprintf("%s-%s", displayName, userUUID) - if scp.usernameLen > 0 && len(username) > scp.usernameLen { - username = username[:scp.usernameLen] - } - - return username, nil -} - -func (scp *sqlCredentialsProducer) GeneratePassword() (string, error) { - password, err := uuid.GenerateUUID() - if err != nil { - return "", err - } - - return password, nil -} - -func (scp *sqlCredentialsProducer) GenerateExpiration(ttl time.Duration) (string, error) { - return time.Now(). - Add(ttl). - Format("2006-01-02 15:04:05-0700"), nil -} - -// cassandraCredentialsProducer implements CredentialsProducer and provides an -// interface for cassandra databases to generate user information. -type cassandraCredentialsProducer struct{} - -func (ccp *cassandraCredentialsProducer) GenerateUsername(displayName string) (string, error) { - userUUID, err := uuid.GenerateUUID() - if err != nil { - return "", err - } - username := fmt.Sprintf("vault_%s_%s_%d", displayName, userUUID, time.Now().Unix()) - username = strings.Replace(username, "-", "_", -1) - - return username, nil -} - -func (ccp *cassandraCredentialsProducer) GeneratePassword() (string, error) { - password, err := uuid.GenerateUUID() - if err != nil { - return "", err - } - - return password, nil -} - -func (ccp *cassandraCredentialsProducer) GenerateExpiration(ttl time.Duration) (string, error) { - return "", nil -} diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go deleted file mode 100644 index 49b18b3b858a..000000000000 --- a/builtin/logical/database/dbs/db.go +++ /dev/null @@ -1,196 +0,0 @@ -package dbs - -import ( - "errors" - "fmt" - "strings" - "time" - - "github.com/hashicorp/vault/logical" - log "github.com/mgutz/logxi/v1" -) - -const ( - postgreSQLTypeName = "postgres" - mySQLTypeName = "mysql" - msSQLTypeName = "mssql" - cassandraTypeName = "cassandra" - pluginTypeName = "plugin" -) - -var ( - ErrUnsupportedDatabaseType = errors.New("unsupported database type") - ErrEmptyCreationStatement = errors.New("empty creation statements") - ErrEmptyPluginName = errors.New("empty plugin name") -) - -// Factory function definition -type Factory func(*DatabaseConfig, logical.SystemView, log.Logger) (DatabaseType, error) - -// BuiltinFactory is used to build builtin database types. It wraps the database -// object in a logging and metrics middleware. -func BuiltinFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Logger) (DatabaseType, error) { - var dbType DatabaseType - - switch conf.DatabaseType { - case postgreSQLTypeName: - connProducer := &sqlConnectionProducer{} - connProducer.config = conf - - credsProducer := &sqlCredentialsProducer{ - displayNameLen: 23, - usernameLen: 63, - } - - dbType = &PostgreSQL{ - ConnectionProducer: connProducer, - CredentialsProducer: credsProducer, - } - - case mySQLTypeName: - connProducer := &sqlConnectionProducer{} - connProducer.config = conf - - credsProducer := &sqlCredentialsProducer{ - displayNameLen: 4, - usernameLen: 16, - } - - dbType = &MySQL{ - ConnectionProducer: connProducer, - CredentialsProducer: credsProducer, - } - - case msSQLTypeName: - connProducer := &sqlConnectionProducer{} - connProducer.config = conf - - credsProducer := &sqlCredentialsProducer{ - displayNameLen: 10, - usernameLen: 63, - } - - dbType = &MSSQL{ - ConnectionProducer: connProducer, - CredentialsProducer: credsProducer, - } - - case cassandraTypeName: - connProducer := &cassandraConnectionProducer{} - connProducer.config = conf - - credsProducer := &cassandraCredentialsProducer{} - - dbType = &Cassandra{ - ConnectionProducer: connProducer, - CredentialsProducer: credsProducer, - } - - default: - return nil, ErrUnsupportedDatabaseType - } - - // Wrap with metrics middleware - dbType = &databaseMetricsMiddleware{ - next: dbType, - typeStr: dbType.Type(), - } - - // Wrap with tracing middleware - dbType = &databaseTracingMiddleware{ - next: dbType, - typeStr: dbType.Type(), - logger: logger, - } - - return dbType, nil -} - -// PluginFactory is used to build plugin database types. It wraps the database -// object in a logging and metrics middleware. -func PluginFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Logger) (DatabaseType, error) { - if conf.PluginName == "" { - return nil, ErrEmptyPluginName - } - - pluginMeta, err := sys.LookupPlugin(conf.PluginName) - if err != nil { - return nil, err - } - - // Make sure the database type is set to plugin - conf.DatabaseType = pluginTypeName - - db, err := newPluginClient(sys, pluginMeta) - if err != nil { - return nil, err - } - - // Wrap with metrics middleware - db = &databaseMetricsMiddleware{ - next: db, - typeStr: db.Type(), - } - - // Wrap with tracing middleware - db = &databaseTracingMiddleware{ - next: db, - typeStr: db.Type(), - logger: logger, - } - - return db, nil -} - -// DatabaseType is the interface that all database objects must implement. -type DatabaseType interface { - Type() string - CreateUser(statements Statements, username, password, expiration string) error - RenewUser(statements Statements, username, expiration string) error - RevokeUser(statements Statements, username string) error - - Initialize(map[string]interface{}) error - Close() error - CredentialsProducer -} - -// DatabaseConfig is used by the Factory function to configure a DatabaseType -// object. -type DatabaseConfig struct { - DatabaseType string `json:"type" structs:"type" mapstructure:"type"` - // ConnectionDetails stores the database specific connection settings needed - // by each database type. - ConnectionDetails map[string]interface{} `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"` - MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"` - MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"` - MaxConnectionLifetime time.Duration `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"` - PluginName string `json:"plugin_name" structs:"plugin_name" mapstructure:"plugin_name"` -} - -// GetFactory returns the appropriate factory method for the given database -// type. -func (dc *DatabaseConfig) GetFactory() Factory { - if dc.DatabaseType == pluginTypeName { - return PluginFactory - } - - return BuiltinFactory -} - -// Statements set in role creation and passed into the database type's functions. -// TODO: Add a way of setting defaults here. -type Statements struct { - CreationStatements string `json:"creation_statments" mapstructure:"creation_statements" structs:"creation_statments"` - RevocationStatements string `json:"revocation_statements" mapstructure:"revocation_statements" structs:"revocation_statements"` - RollbackStatements string `json:"rollback_statements" mapstructure:"rollback_statements" structs:"rollback_statements"` - RenewStatements string `json:"renew_statements" mapstructure:"renew_statements" structs:"renew_statements"` -} - -// Query templates a query for us. -func queryHelper(tpl string, data map[string]string) string { - for k, v := range data { - tpl = strings.Replace(tpl, fmt.Sprintf("{{%s}}", k), v, -1) - } - - return tpl -} diff --git a/builtin/logical/database/dbs/mssql.go b/builtin/logical/database/dbs/mssql.go deleted file mode 100644 index b7439b0a82ca..000000000000 --- a/builtin/logical/database/dbs/mssql.go +++ /dev/null @@ -1,219 +0,0 @@ -package dbs - -import ( - "database/sql" - "fmt" - "strings" - - "github.com/hashicorp/vault/helper/strutil" -) - -// MSSQL is an implementation of DatabaseType interface -type MSSQL struct { - ConnectionProducer - CredentialsProducer -} - -// Type returns the TypeName for this backend -func (m *MSSQL) Type() string { - return msSQLTypeName -} - -func (m *MSSQL) getConnection() (*sql.DB, error) { - db, err := m.connection() - if err != nil { - return nil, err - } - - return db.(*sql.DB), nil -} - -// CreateUser generates the username/password on the underlying MSSQL secret backend as instructed by -// the CreationStatement provided. -func (m *MSSQL) CreateUser(statements Statements, username, password, expiration string) error { - // Grab the lock - m.Lock() - defer m.Unlock() - - // Get the connection - db, err := m.getConnection() - if err != nil { - return err - } - - if statements.CreationStatements == "" { - return ErrEmptyCreationStatement - } - - // Start a transaction - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - - // Execute each query - for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") { - query = strings.TrimSpace(query) - if len(query) == 0 { - continue - } - - stmt, err := tx.Prepare(queryHelper(query, map[string]string{ - "name": username, - "password": password, - })) - if err != nil { - return err - } - defer stmt.Close() - if _, err := stmt.Exec(); err != nil { - return err - } - } - - // Commit the transaction - if err := tx.Commit(); err != nil { - return err - } - - return nil -} - -// RenewUser is not supported on MSSQL, so this is a no-op. -func (m *MSSQL) RenewUser(statements Statements, username, expiration string) error { - // NOOP - return nil -} - -// RevokeUser attempts to drop the specified user. It will first attempt to disable login, -// then kill pending connections from that user, and finally drop the user and login from the -// database instance. -func (m *MSSQL) RevokeUser(statements Statements, username string) error { - // Get connection - db, err := m.getConnection() - if err != nil { - return err - } - - // First disable server login - disableStmt, err := db.Prepare(fmt.Sprintf("ALTER LOGIN [%s] DISABLE;", username)) - if err != nil { - return err - } - defer disableStmt.Close() - if _, err := disableStmt.Exec(); err != nil { - return err - } - - // Query for sessions for the login so that we can kill any outstanding - // sessions. There cannot be any active sessions before we drop the logins - // This isn't done in a transaction because even if we fail along the way, - // we want to remove as much access as possible - sessionStmt, err := db.Prepare(fmt.Sprintf( - "SELECT session_id FROM sys.dm_exec_sessions WHERE login_name = '%s';", username)) - if err != nil { - return err - } - defer sessionStmt.Close() - - sessionRows, err := sessionStmt.Query() - if err != nil { - return err - } - defer sessionRows.Close() - - var revokeStmts []string - for sessionRows.Next() { - var sessionID int - err = sessionRows.Scan(&sessionID) - if err != nil { - return err - } - revokeStmts = append(revokeStmts, fmt.Sprintf("KILL %d;", sessionID)) - } - - // Query for database users using undocumented stored procedure for now since - // it is the easiest way to get this information; - // we need to drop the database users before we can drop the login and the role - // This isn't done in a transaction because even if we fail along the way, - // we want to remove as much access as possible - stmt, err := db.Prepare(fmt.Sprintf("EXEC sp_msloginmappings '%s';", username)) - if err != nil { - return err - } - defer stmt.Close() - - rows, err := stmt.Query() - if err != nil { - return err - } - defer rows.Close() - - for rows.Next() { - var loginName, dbName, qUsername string - var aliasName sql.NullString - err = rows.Scan(&loginName, &dbName, &qUsername, &aliasName) - if err != nil { - return err - } - revokeStmts = append(revokeStmts, fmt.Sprintf(dropUserSQL, dbName, username, username)) - } - - // we do not stop on error, as we want to remove as - // many permissions as possible right now - var lastStmtError error - for _, query := range revokeStmts { - stmt, err := db.Prepare(query) - if err != nil { - lastStmtError = err - continue - } - defer stmt.Close() - _, err = stmt.Exec() - if err != nil { - lastStmtError = err - } - } - - // can't drop if not all database users are dropped - if rows.Err() != nil { - return fmt.Errorf("cound not generate sql statements for all rows: %s", rows.Err()) - } - if lastStmtError != nil { - return fmt.Errorf("could not perform all sql statements: %s", lastStmtError) - } - - // Drop this login - stmt, err = db.Prepare(fmt.Sprintf(dropLoginSQL, username, username)) - if err != nil { - return err - } - defer stmt.Close() - if _, err := stmt.Exec(); err != nil { - return err - } - - return nil -} - -const dropUserSQL = ` -USE [%s] -IF EXISTS - (SELECT name - FROM sys.database_principals - WHERE name = N'%s') -BEGIN - DROP USER [%s] -END -` - -const dropLoginSQL = ` -IF EXISTS - (SELECT name - FROM master.sys.server_principals - WHERE name = N'%s') -BEGIN - DROP LOGIN [%s] -END -` diff --git a/builtin/logical/database/dbs/mssql_test.go b/builtin/logical/database/dbs/mssql_test.go deleted file mode 100644 index f2169299fa60..000000000000 --- a/builtin/logical/database/dbs/mssql_test.go +++ /dev/null @@ -1,221 +0,0 @@ -package dbs - -import ( - "database/sql" - "fmt" - "os" - "sync" - "testing" - "time" - - _ "github.com/denisenkom/go-mssqldb" - log "github.com/mgutz/logxi/v1" - dockertest "gopkg.in/ory-am/dockertest.v3" -) - -var ( - testMSQLImagePull sync.Once -) - -func prepareMSSQLTestContainer(t *testing.T) (cleanup func(), retURL string) { - if os.Getenv("MSSQL_URL") != "" { - return func() {}, os.Getenv("MSSQL_URL") - } - - pool, err := dockertest.NewPool("") - if err != nil { - t.Fatalf("Failed to connect to docker: %s", err) - } - - resource, err := pool.Run("microsoft/mssql-server-linux", "latest", []string{"ACCEPT_EULA=Y", "SA_PASSWORD=yourStrong(!)Password"}) - if err != nil { - t.Fatalf("Could not start local MSSQL docker container: %s", err) - } - - cleanup = func() { - err := pool.Purge(resource) - if err != nil { - t.Fatalf("Failed to cleanup local DynamoDB: %s", err) - } - } - - retURL = fmt.Sprintf("sqlserver://sa:yourStrong(!)Password@localhost:%s", resource.GetPort("1433/tcp")) - - // exponential backoff-retry, because the mssql container may not be able to accept connections yet - if err = pool.Retry(func() error { - var err error - var db *sql.DB - db, err = sql.Open("mssql", retURL) - if err != nil { - return err - } - return db.Ping() - }); err != nil { - t.Fatalf("Could not connect to MSSQL docker container: %s", err) - } - - return -} - -func TestMSSQL_Initialize(t *testing.T) { - cleanup, connURL := prepareMSSQLTestContainer(t) - defer cleanup() - - conf := &DatabaseConfig{ - DatabaseType: msSQLTypeName, - ConnectionDetails: map[string]interface{}{ - "connection_url": connURL, - }, - } - - dbRaw, err := BuiltinFactory(conf, nil, &log.NullLogger{}) - if err != nil { - t.Fatalf("err: %s", err) - } - - // Deconsturct the middleware chain to get the underlying mssql object - dbTracer := dbRaw.(*databaseTracingMiddleware) - dbMetrics := dbTracer.next.(*databaseMetricsMiddleware) - db := dbMetrics.next.(*MSSQL) - connProducer := db.ConnectionProducer.(*sqlConnectionProducer) - - err = dbRaw.Initialize(conf.ConnectionDetails) - if err != nil { - t.Fatalf("err: %s", err) - } - - if !connProducer.initalized { - t.Fatal("Database should be initalized") - } - - err = dbRaw.Close() - if err != nil { - t.Fatalf("err: %s", err) - } - - if connProducer.db != nil { - t.Fatal("db object should be nil") - } -} - -func TestMSSQL_CreateUser(t *testing.T) { - cleanup, connURL := prepareMSSQLTestContainer(t) - defer cleanup() - - conf := &DatabaseConfig{ - DatabaseType: msSQLTypeName, - ConnectionDetails: map[string]interface{}{ - "connection_url": connURL, - }, - } - - db, err := BuiltinFactory(conf, nil, &log.NullLogger{}) - if err != nil { - t.Fatalf("err: %s", err) - } - err = db.Initialize(conf.ConnectionDetails) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err := db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err := db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err := db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - - // Test with no configured Creation Statememt - err = db.CreateUser(Statements{}, username, password, expiration) - if err == nil { - t.Fatal("Expected error when no creation statement is provided") - } - - statements := Statements{ - CreationStatements: testMSSQLRole, - } - - err = db.CreateUser(statements, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err = db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err = db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err = db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } -} - -func TestMSSQL_RevokeUser(t *testing.T) { - cleanup, connURL := prepareMSSQLTestContainer(t) - defer cleanup() - - conf := &DatabaseConfig{ - DatabaseType: msSQLTypeName, - ConnectionDetails: map[string]interface{}{ - "connection_url": connURL, - }, - } - - db, err := BuiltinFactory(conf, nil, &log.NullLogger{}) - if err != nil { - t.Fatalf("err: %s", err) - } - err = db.Initialize(conf.ConnectionDetails) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err := db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err := db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err := db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - - statements := Statements{ - CreationStatements: testMSSQLRole, - } - - err = db.CreateUser(statements, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } - - // Test default revoke statememts - err = db.RevokeUser(statements, username) - if err != nil { - t.Fatalf("err: %s", err) - } -} - -const testMSSQLRole = ` -CREATE LOGIN [{{name}}] WITH PASSWORD = '{{password}}'; -CREATE USER [{{name}}] FOR LOGIN [{{name}}]; -GRANT SELECT, INSERT, UPDATE, DELETE ON SCHEMA::dbo TO [{{name}}];` diff --git a/builtin/logical/database/dbs/mysql.go b/builtin/logical/database/dbs/mysql.go deleted file mode 100644 index 54940d8f65fc..000000000000 --- a/builtin/logical/database/dbs/mysql.go +++ /dev/null @@ -1,135 +0,0 @@ -package dbs - -import ( - "database/sql" - "strings" - - "github.com/hashicorp/vault/helper/strutil" -) - -const defaultMysqlRevocationStmts = ` - REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'%'; - DROP USER '{{name}}'@'%' -` - -type MySQL struct { - ConnectionProducer - CredentialsProducer -} - -func (m *MySQL) Type() string { - return mySQLTypeName -} - -func (m *MySQL) getConnection() (*sql.DB, error) { - db, err := m.connection() - if err != nil { - return nil, err - } - - return db.(*sql.DB), nil -} - -func (m *MySQL) CreateUser(statements Statements, username, password, expiration string) error { - // Grab the lock - m.Lock() - defer m.Unlock() - - // Get the connection - db, err := m.getConnection() - if err != nil { - return err - } - - if statements.CreationStatements == "" { - return ErrEmptyCreationStatement - } - - // Start a transaction - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - - // Execute each query - for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") { - query = strings.TrimSpace(query) - if len(query) == 0 { - continue - } - - stmt, err := tx.Prepare(queryHelper(query, map[string]string{ - "name": username, - "password": password, - })) - if err != nil { - return err - } - defer stmt.Close() - if _, err := stmt.Exec(); err != nil { - return err - } - } - - // Commit the transaction - if err := tx.Commit(); err != nil { - return err - } - - return nil -} - -// NOOP -func (m *MySQL) RenewUser(statements Statements, username, expiration string) error { - return nil -} - -func (m *MySQL) RevokeUser(statements Statements, username string) error { - // Grab the read lock - m.Lock() - defer m.Unlock() - - // Get the connection - db, err := m.getConnection() - if err != nil { - return err - } - - revocationStmts := statements.RevocationStatements - // Use a default SQL statement for revocation if one cannot be fetched from the role - if revocationStmts == "" { - revocationStmts = defaultMysqlRevocationStmts - } - - // Start a transaction - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - - for _, query := range strutil.ParseArbitraryStringSlice(revocationStmts, ";") { - query = strings.TrimSpace(query) - if len(query) == 0 { - continue - } - - // This is not a prepared statement because not all commands are supported - // 1295: This command is not supported in the prepared statement protocol yet - // Reference https://mariadb.com/kb/en/mariadb/prepare-statement/ - query = strings.Replace(query, "{{name}}", username, -1) - _, err = tx.Exec(query) - if err != nil { - return err - } - - } - - // Commit the transaction - if err := tx.Commit(); err != nil { - return err - } - - return nil -} diff --git a/builtin/logical/database/dbs/mysql_test.go b/builtin/logical/database/dbs/mysql_test.go deleted file mode 100644 index 553acc8ffd52..000000000000 --- a/builtin/logical/database/dbs/mysql_test.go +++ /dev/null @@ -1,346 +0,0 @@ -package dbs - -import ( - "database/sql" - "os" - "sync" - "testing" - "time" - - log "github.com/mgutz/logxi/v1" - dockertest "gopkg.in/ory-am/dockertest.v2" -) - -var ( - testMySQLImagePull sync.Once -) - -func prepareMySQLTestContainer(t *testing.T) (cid dockertest.ContainerID, retURL string) { - if os.Getenv("MYSQL_URL") != "" { - return "", os.Getenv("MYSQL_URL") - } - - // Without this the checks for whether the container has started seem to - // never actually pass. There's really no reason to expose the test - // containers, so don't. - dockertest.BindDockerToLocalhost = "yep" - - testMySQLImagePull.Do(func() { - dockertest.Pull("mysql") - }) - - cid, connErr := dockertest.ConnectToMySQL(60, 500*time.Millisecond, func(connURL string) bool { - // This will cause a validation to run - connProducer := &sqlConnectionProducer{} - connProducer.ConnectionURL = connURL - connProducer.config = &DatabaseConfig{ - DatabaseType: mySQLTypeName, - } - - conn, err := connProducer.connection() - if err != nil { - return false - } - if err := conn.(*sql.DB).Ping(); err != nil { - return false - } - - connProducer.Close() - - retURL = connURL - return true - }) - - if connErr != nil { - t.Fatalf("could not connect to database: %v", connErr) - } - - return -} - -func TestMySQL_Initialize(t *testing.T) { - cid, connURL := prepareMySQLTestContainer(t) - if cid != "" { - defer cleanupTestContainer(t, cid) - } - - conf := &DatabaseConfig{ - DatabaseType: mySQLTypeName, - ConnectionDetails: map[string]interface{}{ - "connection_url": connURL, - }, - } - - dbRaw, err := BuiltinFactory(conf, nil, &log.NullLogger{}) - if err != nil { - t.Fatalf("err: %s", err) - } - - // Deconsturct the middleware chain to get the underlying mysql object - dbTracer := dbRaw.(*databaseTracingMiddleware) - dbMetrics := dbTracer.next.(*databaseMetricsMiddleware) - db := dbMetrics.next.(*MySQL) - connProducer := db.ConnectionProducer.(*sqlConnectionProducer) - - err = dbRaw.Initialize(conf.ConnectionDetails) - if err != nil { - t.Fatalf("err: %s", err) - } - - if !connProducer.initalized { - t.Fatal("Database should be initalized") - } - - err = dbRaw.Close() - if err != nil { - t.Fatalf("err: %s", err) - } - - if connProducer.db != nil { - t.Fatal("db object should be nil") - } -} - -func TestMySQL_CreateUser(t *testing.T) { - cid, connURL := prepareMySQLTestContainer(t) - if cid != "" { - defer cleanupTestContainer(t, cid) - } - - conf := &DatabaseConfig{ - DatabaseType: mySQLTypeName, - ConnectionDetails: map[string]interface{}{ - "connection_url": connURL, - }, - } - - db, err := BuiltinFactory(conf, nil, &log.NullLogger{}) - if err != nil { - t.Fatalf("err: %s", err) - } - err = db.Initialize(conf.ConnectionDetails) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err := db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err := db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err := db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - - // Test with no configured Creation Statememt - err = db.CreateUser(Statements{}, username, password, expiration) - if err == nil { - t.Fatal("Expected error when no creation statement is provided") - } - - statements := Statements{ - CreationStatements: testMySQLRoleWildCard, - } - - err = db.CreateUser(statements, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err = db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err = db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err = db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - statements.CreationStatements = testMySQLRoleHost - err = db.CreateUser(statements, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err = db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err = db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err = db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } -} - -func TestMySQL_RenewUser(t *testing.T) { - cid, connURL := prepareMySQLTestContainer(t) - if cid != "" { - defer cleanupTestContainer(t, cid) - } - - conf := &DatabaseConfig{ - DatabaseType: mySQLTypeName, - ConnectionDetails: map[string]interface{}{ - "connection_url": connURL, - }, - } - - db, err := BuiltinFactory(conf, nil, &log.NullLogger{}) - if err != nil { - t.Fatalf("err: %s", err) - } - err = db.Initialize(conf.ConnectionDetails) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err := db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err := db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err := db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - - statements := Statements{ - CreationStatements: testMySQLRoleWildCard, - } - - err = db.CreateUser(statements, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err = db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - err = db.RenewUser(statements, username, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } -} - -func TestMySQL_RevokeUser(t *testing.T) { - cid, connURL := prepareMySQLTestContainer(t) - if cid != "" { - defer cleanupTestContainer(t, cid) - } - - conf := &DatabaseConfig{ - DatabaseType: mySQLTypeName, - ConnectionDetails: map[string]interface{}{ - "connection_url": connURL, - }, - } - - db, err := BuiltinFactory(conf, nil, &log.NullLogger{}) - if err != nil { - t.Fatalf("err: %s", err) - } - err = db.Initialize(conf.ConnectionDetails) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err := db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err := db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err := db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - - statements := Statements{ - CreationStatements: testMySQLRoleWildCard, - } - - err = db.CreateUser(statements, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err = db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - - // Test default revoke statememts - err = db.RevokeUser(statements, username) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err = db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err = db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err = db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - - statements.CreationStatements = testMySQLRoleHost - err = db.CreateUser(statements, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } - - // Test custom revoke statements - statements.RevocationStatements = testMySQLRevocationSQL - err = db.RevokeUser(statements, username) - if err != nil { - t.Fatalf("err: %s", err) - } - -} - -const testMySQLRoleWildCard = ` -CREATE USER '{{name}}'@'%' IDENTIFIED BY '{{password}}'; -GRANT SELECT ON *.* TO '{{name}}'@'%'; -` -const testMySQLRoleHost = ` -CREATE USER '{{name}}'@'10.1.1.2' IDENTIFIED BY '{{password}}'; -GRANT SELECT ON *.* TO '{{name}}'@'10.1.1.2'; -` -const testMySQLRevocationSQL = ` -REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'10.1.1.2'; -DROP USER '{{name}}'@'10.1.1.2'; -` diff --git a/builtin/logical/database/dbs/postgresql.go b/builtin/logical/database/dbs/postgresql.go deleted file mode 100644 index c8ba110cf7a8..000000000000 --- a/builtin/logical/database/dbs/postgresql.go +++ /dev/null @@ -1,279 +0,0 @@ -package dbs - -import ( - "database/sql" - "fmt" - "strings" - - "github.com/hashicorp/vault/helper/strutil" - "github.com/lib/pq" -) - -type PostgreSQL struct { - ConnectionProducer - CredentialsProducer -} - -func (p *PostgreSQL) Type() string { - return postgreSQLTypeName -} - -func (p *PostgreSQL) getConnection() (*sql.DB, error) { - db, err := p.connection() - if err != nil { - return nil, err - } - - return db.(*sql.DB), nil -} - -func (p *PostgreSQL) CreateUser(statements Statements, username, password, expiration string) error { - if statements.CreationStatements == "" { - return ErrEmptyCreationStatement - } - - // Grab the lock - p.Lock() - defer p.Unlock() - - // Get the connection - db, err := p.getConnection() - if err != nil { - return err - } - - // Start a transaction - tx, err := db.Begin() - if err != nil { - return err - } - defer func() { - tx.Rollback() - }() - // Return the secret - - // Execute each query - for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") { - query = strings.TrimSpace(query) - if len(query) == 0 { - continue - } - - stmt, err := tx.Prepare(queryHelper(query, map[string]string{ - "name": username, - "password": password, - "expiration": expiration, - })) - if err != nil { - return err - } - defer stmt.Close() - if _, err := stmt.Exec(); err != nil { - return err - } - } - - // Commit the transaction - if err := tx.Commit(); err != nil { - return err - } - - return nil -} - -func (p *PostgreSQL) RenewUser(statements Statements, username, expiration string) error { - // Grab the lock - p.Lock() - defer p.Unlock() - - db, err := p.getConnection() - if err != nil { - return err - } - - query := fmt.Sprintf( - "ALTER ROLE %s VALID UNTIL '%s';", - pq.QuoteIdentifier(username), - expiration) - - stmt, err := db.Prepare(query) - if err != nil { - return err - } - defer stmt.Close() - if _, err := stmt.Exec(); err != nil { - return err - } - - return nil -} - -func (p *PostgreSQL) RevokeUser(statements Statements, username string) error { - // Grab the lock - p.Lock() - defer p.Unlock() - - if statements.RevocationStatements == "" { - return p.defaultRevokeUser(username) - } - - return p.customRevokeUser(username, statements.RevocationStatements) -} - -func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error { - db, err := p.getConnection() - if err != nil { - return err - } - - tx, err := db.Begin() - if err != nil { - return err - } - defer func() { - tx.Rollback() - }() - - for _, query := range strutil.ParseArbitraryStringSlice(revocationStmts, ";") { - query = strings.TrimSpace(query) - if len(query) == 0 { - continue - } - - stmt, err := tx.Prepare(queryHelper(query, map[string]string{ - "name": username, - })) - if err != nil { - return err - } - defer stmt.Close() - - if _, err := stmt.Exec(); err != nil { - return err - } - } - - if err := tx.Commit(); err != nil { - return err - } - - return nil -} - -func (p *PostgreSQL) defaultRevokeUser(username string) error { - db, err := p.getConnection() - if err != nil { - return err - } - - // Check if the role exists - var exists bool - err = db.QueryRow("SELECT exists (SELECT rolname FROM pg_roles WHERE rolname=$1);", username).Scan(&exists) - if err != nil && err != sql.ErrNoRows { - return err - } - - if exists == false { - return nil - } - - // Query for permissions; we need to revoke permissions before we can drop - // the role - // This isn't done in a transaction because even if we fail along the way, - // we want to remove as much access as possible - stmt, err := db.Prepare("SELECT DISTINCT table_schema FROM information_schema.role_column_grants WHERE grantee=$1;") - if err != nil { - return err - } - defer stmt.Close() - - rows, err := stmt.Query(username) - if err != nil { - return err - } - defer rows.Close() - - const initialNumRevocations = 16 - revocationStmts := make([]string, 0, initialNumRevocations) - for rows.Next() { - var schema string - err = rows.Scan(&schema) - if err != nil { - // keep going; remove as many permissions as possible right now - continue - } - revocationStmts = append(revocationStmts, fmt.Sprintf( - `REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA %s FROM %s;`, - pq.QuoteIdentifier(schema), - pq.QuoteIdentifier(username))) - - revocationStmts = append(revocationStmts, fmt.Sprintf( - `REVOKE USAGE ON SCHEMA %s FROM %s;`, - pq.QuoteIdentifier(schema), - pq.QuoteIdentifier(username))) - } - - // for good measure, revoke all privileges and usage on schema public - revocationStmts = append(revocationStmts, fmt.Sprintf( - `REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM %s;`, - pq.QuoteIdentifier(username))) - - revocationStmts = append(revocationStmts, fmt.Sprintf( - "REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM %s;", - pq.QuoteIdentifier(username))) - - revocationStmts = append(revocationStmts, fmt.Sprintf( - "REVOKE USAGE ON SCHEMA public FROM %s;", - pq.QuoteIdentifier(username))) - - // get the current database name so we can issue a REVOKE CONNECT for - // this username - var dbname sql.NullString - if err := db.QueryRow("SELECT current_database();").Scan(&dbname); err != nil { - return err - } - - if dbname.Valid { - revocationStmts = append(revocationStmts, fmt.Sprintf( - `REVOKE CONNECT ON DATABASE %s FROM %s;`, - pq.QuoteIdentifier(dbname.String), - pq.QuoteIdentifier(username))) - } - - // again, here, we do not stop on error, as we want to remove as - // many permissions as possible right now - var lastStmtError error - for _, query := range revocationStmts { - stmt, err := db.Prepare(query) - if err != nil { - lastStmtError = err - continue - } - defer stmt.Close() - _, err = stmt.Exec() - if err != nil { - lastStmtError = err - } - } - - // can't drop if not all privileges are revoked - if rows.Err() != nil { - return fmt.Errorf("could not generate revocation statements for all rows: %s", rows.Err()) - } - if lastStmtError != nil { - return fmt.Errorf("could not perform all revocation statements: %s", lastStmtError) - } - - // Drop this user - stmt, err = db.Prepare(fmt.Sprintf( - `DROP ROLE IF EXISTS %s;`, pq.QuoteIdentifier(username))) - if err != nil { - return err - } - defer stmt.Close() - if _, err := stmt.Exec(); err != nil { - return err - } - - return nil -} diff --git a/builtin/logical/database/dbs/postgresql_test.go b/builtin/logical/database/dbs/postgresql_test.go deleted file mode 100644 index 83aed50ba91b..000000000000 --- a/builtin/logical/database/dbs/postgresql_test.go +++ /dev/null @@ -1,414 +0,0 @@ -package dbs - -import ( - "database/sql" - "os" - "sync" - "testing" - "time" - - log "github.com/mgutz/logxi/v1" - dockertest "gopkg.in/ory-am/dockertest.v2" -) - -var ( - testPostgresImagePull sync.Once -) - -func preparePostgresTestContainer(t *testing.T) (cid dockertest.ContainerID, retURL string) { - if os.Getenv("PG_URL") != "" { - return "", os.Getenv("PG_URL") - } - - // Without this the checks for whether the container has started seem to - // never actually pass. There's really no reason to expose the test - // containers, so don't. - dockertest.BindDockerToLocalhost = "yep" - - testPostgresImagePull.Do(func() { - dockertest.Pull("postgres") - }) - - cid, connErr := dockertest.ConnectToPostgreSQL(60, 500*time.Millisecond, func(connURL string) bool { - // This will cause a validation to run - connProducer := &sqlConnectionProducer{} - connProducer.ConnectionURL = connURL - connProducer.config = &DatabaseConfig{ - DatabaseType: postgreSQLTypeName, - } - - conn, err := connProducer.connection() - if err != nil { - return false - } - if err := conn.(*sql.DB).Ping(); err != nil { - return false - } - - connProducer.Close() - - retURL = connURL - return true - }) - - if connErr != nil { - t.Fatalf("could not connect to database: %v", connErr) - } - - return -} - -func cleanupTestContainer(t *testing.T, cid dockertest.ContainerID) { - err := cid.KillRemove() - if err != nil { - t.Fatal(err) - } -} - -func TestPostgreSQL_Initialize(t *testing.T) { - cid, connURL := preparePostgresTestContainer(t) - if cid != "" { - defer cleanupTestContainer(t, cid) - } - - conf := &DatabaseConfig{ - DatabaseType: postgreSQLTypeName, - ConnectionDetails: map[string]interface{}{ - "connection_url": connURL, - }, - } - - dbRaw, err := BuiltinFactory(conf, nil, &log.NullLogger{}) - if err != nil { - t.Fatalf("err: %s", err) - } - - // Deconsturct the middleware chain to get the underlying postgres object - dbTracer := dbRaw.(*databaseTracingMiddleware) - dbMetrics := dbTracer.next.(*databaseMetricsMiddleware) - db := dbMetrics.next.(*PostgreSQL) - connProducer := db.ConnectionProducer.(*sqlConnectionProducer) - - err = dbRaw.Initialize(conf.ConnectionDetails) - if err != nil { - t.Fatalf("err: %s", err) - } - - if !connProducer.initalized { - t.Fatal("Database should be initalized") - } - - err = dbRaw.Close() - if err != nil { - t.Fatalf("err: %s", err) - } - - if connProducer.db != nil { - t.Fatal("db object should be nil") - } -} - -func TestPostgreSQL_CreateUser(t *testing.T) { - cid, connURL := preparePostgresTestContainer(t) - if cid != "" { - defer cleanupTestContainer(t, cid) - } - - conf := &DatabaseConfig{ - DatabaseType: postgreSQLTypeName, - ConnectionDetails: map[string]interface{}{ - "connection_url": connURL, - }, - } - - db, err := BuiltinFactory(conf, nil, &log.NullLogger{}) - if err != nil { - t.Fatalf("err: %s", err) - } - err = db.Initialize(conf.ConnectionDetails) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err := db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err := db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err := db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - - // Test with no configured Creation Statememt - err = db.CreateUser(Statements{}, username, password, expiration) - if err == nil { - t.Fatal("Expected error when no creation statement is provided") - } - - statements := Statements{ - CreationStatements: testPostgresRole, - } - - err = db.CreateUser(statements, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err = db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err = db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err = db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - statements.CreationStatements = testPostgresReadOnlyRole - err = db.CreateUser(statements, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err = db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err = db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err = db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - /* statements.CreationStatements = testBlockStatementRole - err = db.CreateUser(statements, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - }*/ -} - -func TestPostgreSQL_RenewUser(t *testing.T) { - cid, connURL := preparePostgresTestContainer(t) - if cid != "" { - defer cleanupTestContainer(t, cid) - } - - conf := &DatabaseConfig{ - DatabaseType: postgreSQLTypeName, - ConnectionDetails: map[string]interface{}{ - "connection_url": connURL, - }, - } - - db, err := BuiltinFactory(conf, nil, &log.NullLogger{}) - if err != nil { - t.Fatalf("err: %s", err) - } - err = db.Initialize(conf.ConnectionDetails) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err := db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err := db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err := db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - - statements := Statements{ - CreationStatements: testPostgresRole, - } - - err = db.CreateUser(statements, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err = db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - err = db.RenewUser(statements, username, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } -} - -func TestPostgreSQL_RevokeUser(t *testing.T) { - cid, connURL := preparePostgresTestContainer(t) - if cid != "" { - defer cleanupTestContainer(t, cid) - } - - conf := &DatabaseConfig{ - DatabaseType: postgreSQLTypeName, - ConnectionDetails: map[string]interface{}{ - "connection_url": connURL, - }, - } - - db, err := BuiltinFactory(conf, nil, &log.NullLogger{}) - if err != nil { - t.Fatalf("err: %s", err) - } - err = db.Initialize(conf.ConnectionDetails) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err := db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err := db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err := db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - - statements := Statements{ - CreationStatements: testPostgresRole, - } - - err = db.CreateUser(statements, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err = db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - - // Test default revoke statememts - err = db.RevokeUser(statements, username) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err = db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err = db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err = db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - - err = db.CreateUser(statements, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } - - // Test custom revoke statements - statements.RevocationStatements = defaultPostgresRevocationSQL - err = db.RevokeUser(statements, username) - if err != nil { - t.Fatalf("err: %s", err) - } - -} - -const testPostgresRole = ` -CREATE ROLE "{{name}}" WITH - LOGIN - PASSWORD '{{password}}' - VALID UNTIL '{{expiration}}'; -GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}"; -` - -const testPostgresReadOnlyRole = ` -CREATE ROLE "{{name}}" WITH - LOGIN - PASSWORD '{{password}}' - VALID UNTIL '{{expiration}}'; -GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}"; -GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}"; -` - -const testPostgresBlockStatementRole = ` -DO $$ -BEGIN - IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN - CREATE ROLE "foo-role"; - CREATE SCHEMA IF NOT EXISTS foo AUTHORIZATION "foo-role"; - ALTER ROLE "foo-role" SET search_path = foo; - GRANT TEMPORARY ON DATABASE "postgres" TO "foo-role"; - GRANT ALL PRIVILEGES ON SCHEMA foo TO "foo-role"; - GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA foo TO "foo-role"; - GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA foo TO "foo-role"; - GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA foo TO "foo-role"; - END IF; -END -$$ - -CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}'; -GRANT "foo-role" TO "{{name}}"; -ALTER ROLE "{{name}}" SET search_path = foo; -GRANT CONNECT ON DATABASE "postgres" TO "{{name}}"; -` - -var testPostgresBlockStatementRoleSlice = []string{ - ` -DO $$ -BEGIN - IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN - CREATE ROLE "foo-role"; - CREATE SCHEMA IF NOT EXISTS foo AUTHORIZATION "foo-role"; - ALTER ROLE "foo-role" SET search_path = foo; - GRANT TEMPORARY ON DATABASE "postgres" TO "foo-role"; - GRANT ALL PRIVILEGES ON SCHEMA foo TO "foo-role"; - GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA foo TO "foo-role"; - GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA foo TO "foo-role"; - GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA foo TO "foo-role"; - END IF; -END -$$ -`, - `CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}';`, - `GRANT "foo-role" TO "{{name}}";`, - `ALTER ROLE "{{name}}" SET search_path = foo;`, - `GRANT CONNECT ON DATABASE "postgres" TO "{{name}}";`, -} - -const defaultPostgresRevocationSQL = ` -REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM "{{name}}"; -REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM "{{name}}"; -REVOKE USAGE ON SCHEMA public FROM "{{name}}"; - -DROP ROLE IF EXISTS "{{name}}"; -` diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index be2038c31cab..48d9b88803f6 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -3,10 +3,8 @@ package database import ( "fmt" "strings" - "time" "github.com/fatih/structs" - "github.com/hashicorp/vault/builtin/logical/database/dbs" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -50,16 +48,10 @@ func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framew return nil, nil } -// pathConfigureBuiltinConnection returns a configured framework.Path setup to -// operate on builtin databases. -func pathConfigureBuiltinConnection(b *databaseBackend) *framework.Path { - return buildConfigConnectionPath("dbs/%s", b.connectionWriteHandler(dbs.BuiltinFactory), b.connectionReadHandler(), b.connectionDeleteHandler()) -} - // pathConfigurePluginConnection returns a configured framework.Path setup to // operate on plugins. func pathConfigurePluginConnection(b *databaseBackend) *framework.Path { - return buildConfigConnectionPath("dbs/plugin/%s", b.connectionWriteHandler(dbs.PluginFactory), b.connectionReadHandler(), b.connectionDeleteHandler()) + return buildConfigConnectionPath("config/%s", b.connectionWriteHandler(), b.connectionReadHandler(), b.connectionDeleteHandler()) } // buildConfigConnectionPath reutns a configured framework.Path using the passed @@ -74,40 +66,12 @@ func buildConfigConnectionPath(path string, updateOp, readOp, deleteOp framework Description: "Name of this DB type", }, - "connection_type": &framework.FieldSchema{ - Type: framework.TypeString, - Description: "DB type (e.g. postgres)", - }, - "verify_connection": &framework.FieldSchema{ Type: framework.TypeBool, Default: true, Description: `If set, connection_url is verified by actually connecting to the database`, }, - "max_open_connections": &framework.FieldSchema{ - Type: framework.TypeInt, - Description: `Maximum number of open connections to the database; -a zero uses the default value of two and a -negative value means unlimited`, - }, - - "max_idle_connections": &framework.FieldSchema{ - Type: framework.TypeInt, - Description: `Maximum number of idle connections to the database; -a zero uses the value of max_open_connections -and a negative value disables idle connections. -If larger than max_open_connections it will be -reduced to the same size.`, - }, - - "max_connection_lifetime": &framework.FieldSchema{ - Type: framework.TypeString, - Default: "0s", - Description: `Maximum amount of time a connection may be reused; - a zero or negative value reuses connections forever.`, - }, - "plugin_name": &framework.FieldSchema{ Type: framework.TypeString, Description: `Maximum amount of time a connection may be reused; @@ -139,7 +103,7 @@ func (b *databaseBackend) connectionReadHandler() framework.OperationFunc { return nil, nil } - var config dbs.DatabaseConfig + var config DatabaseConfig if err := entry.DecodeJSON(&config); err != nil { return nil, err } @@ -180,40 +144,12 @@ func (b *databaseBackend) connectionDeleteHandler() framework.OperationFunc { // connectionWriteHandler returns a handler function for creating and updating // both builtin and plugin database types. -func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework.OperationFunc { +func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - connType := data.Get("connection_type").(string) - if connType == "" { - return logical.ErrorResponse("connection_type not set"), nil - } - - maxOpenConns := data.Get("max_open_connections").(int) - if maxOpenConns == 0 { - maxOpenConns = 2 - } - - maxIdleConns := data.Get("max_idle_connections").(int) - if maxIdleConns == 0 { - maxIdleConns = maxOpenConns - } - if maxIdleConns > maxOpenConns { - maxIdleConns = maxOpenConns - } - - maxConnLifetimeRaw := data.Get("max_connection_lifetime").(string) - maxConnLifetime, err := time.ParseDuration(maxConnLifetimeRaw) - if err != nil { - return logical.ErrorResponse(fmt.Sprintf( - "Invalid max_connection_lifetime: %s", err)), nil - } - config := &dbs.DatabaseConfig{ - DatabaseType: connType, - ConnectionDetails: data.Raw, - MaxOpenConnections: maxOpenConns, - MaxIdleConnections: maxIdleConns, - MaxConnectionLifetime: maxConnLifetime, - PluginName: data.Get("plugin_name").(string), + config := &DatabaseConfig{ + ConnectionDetails: data.Raw, + PluginName: data.Get("plugin_name").(string), } name := data.Get("name").(string) @@ -227,7 +163,7 @@ func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework. b.Lock() defer b.Unlock() - db, err := factory(config, b.System(), b.logger) + db, err := PluginFactory(config, b.System(), b.logger) if err != nil { return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil } diff --git a/builtin/logical/database/path_roles.go b/builtin/logical/database/path_roles.go index 6f62c79d98a7..d099ef17871e 100644 --- a/builtin/logical/database/path_roles.go +++ b/builtin/logical/database/path_roles.go @@ -4,7 +4,6 @@ import ( "fmt" "time" - "github.com/hashicorp/vault/builtin/logical/database/dbs" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -156,7 +155,7 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F "Invalid max_ttl: %s", err)), nil } - statements := dbs.Statements{ + statements := Statements{ CreationStatements: creationStmts, RevocationStatements: revocationStmts, RollbackStatements: rollbackStmts, @@ -183,10 +182,10 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F } type roleEntry struct { - DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"` - Statements dbs.Statements `json:"statments" mapstructure:"statements" structs:"statments"` - DefaultTTL time.Duration `json:"default_ttl" mapstructure:"default_ttl" structs:"default_ttl"` - MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl" structs:"max_ttl"` + DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"` + Statements Statements `json:"statments" mapstructure:"statements" structs:"statments"` + DefaultTTL time.Duration `json:"default_ttl" mapstructure:"default_ttl" structs:"default_ttl"` + MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl" structs:"max_ttl"` } const pathRoleHelpSyn = ` diff --git a/builtin/logical/database/dbs/plugin.go b/builtin/logical/database/plugin.go similarity index 88% rename from builtin/logical/database/dbs/plugin.go rename to builtin/logical/database/plugin.go index 441f97ca0fd8..5a6a8e328598 100644 --- a/builtin/logical/database/dbs/plugin.go +++ b/builtin/logical/database/plugin.go @@ -1,6 +1,7 @@ -package dbs +package database import ( + "errors" "fmt" "net/rpc" "sync" @@ -8,8 +9,47 @@ import ( "github.com/hashicorp/go-plugin" "github.com/hashicorp/vault/helper/pluginutil" + "github.com/hashicorp/vault/logical" + log "github.com/mgutz/logxi/v1" ) +var ( + ErrEmptyPluginName = errors.New("empty plugin name") +) + +// PluginFactory is used to build plugin database types. It wraps the database +// object in a logging and metrics middleware. +func PluginFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Logger) (DatabaseType, error) { + if conf.PluginName == "" { + return nil, ErrEmptyPluginName + } + + pluginMeta, err := sys.LookupPlugin(conf.PluginName) + if err != nil { + return nil, err + } + + db, err := newPluginClient(sys, pluginMeta) + if err != nil { + return nil, err + } + + // Wrap with metrics middleware + db = &databaseMetricsMiddleware{ + next: db, + typeStr: db.Type(), + } + + // Wrap with tracing middleware + db = &databaseTracingMiddleware{ + next: db, + typeStr: db.Type(), + logger: logger, + } + + return db, nil +} + // handshakeConfigs are used to just do a basic handshake between // a plugin and host. If the handshake fails, a user friendly error is shown. // This prevents users from executing bad plugins or executing a plugin @@ -33,7 +73,7 @@ func (DatabasePlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, e } // DatabasePluginClient embeds a databasePluginRPCClient and wraps it's close -// method to also call Close() on the plugin.Client. +// method to also call Kill() on the plugin.Client. type DatabasePluginClient struct { client *plugin.Client sync.Mutex diff --git a/builtin/logical/database/dbs/plugin_test.go b/builtin/logical/database/plugin_test.go similarity index 99% rename from builtin/logical/database/dbs/plugin_test.go rename to builtin/logical/database/plugin_test.go index 60cb6814dd5f..2ec01c9556bb 100644 --- a/builtin/logical/database/dbs/plugin_test.go +++ b/builtin/logical/database/plugin_test.go @@ -1,4 +1,4 @@ -package dbs +package database import ( "crypto/sha256" diff --git a/command/plugin-exec.go b/command/plugin-exec.go index f0d6a8d51a5d..70bc8ae1d4d3 100644 --- a/command/plugin-exec.go +++ b/command/plugin-exec.go @@ -4,7 +4,7 @@ import ( "fmt" "strings" - "github.com/hashicorp/vault/helper/pluginutil" + "github.com/hashicorp/vault/helper/builtinplugins" "github.com/hashicorp/vault/meta" ) @@ -29,7 +29,7 @@ func (c *PluginExec) Run(args []string) int { pluginName := args[0] - runner, ok := pluginutil.BuiltinPlugins[pluginName] + runner, ok := builtinplugins.BuiltinPlugins[pluginName] if !ok { c.Ui.Error(fmt.Sprintf( "No plugin with the name %s found", pluginName)) diff --git a/helper/builtinplugins/builtin.go b/helper/builtinplugins/builtin.go new file mode 100644 index 000000000000..6880640d159e --- /dev/null +++ b/helper/builtinplugins/builtin.go @@ -0,0 +1,8 @@ +package builtinplugins + +import "github.com/hashicorp/vault-plugins/database/mysql" + +var BuiltinPlugins = map[string]func() error{ + "mysql-database-plugin": mysql.Run, + // "postgres-database-plugin": postgres.Run, +} diff --git a/helper/pluginutil/builtin.go b/helper/pluginutil/builtin.go deleted file mode 100644 index 6a464bb8243a..000000000000 --- a/helper/pluginutil/builtin.go +++ /dev/null @@ -1,6 +0,0 @@ -package pluginutil - -var BuiltinPlugins = map[string]func() error{ -// "mysql-database-plugin": mysql.Run, -// "postgres-database-plugin": postgres.Run, -} diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go index c6e4e4059bba..b9c15db22a52 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugin_catalog.go @@ -8,6 +8,7 @@ import ( "strings" "sync" + "github.com/hashicorp/vault/helper/builtinplugins" "github.com/hashicorp/vault/helper/jsonutil" "github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/logical" @@ -53,7 +54,7 @@ func (c *PluginCatalog) Get(name string) (*pluginutil.PluginRunner, error) { } // Look for builtin plugins - if _, ok := pluginutil.BuiltinPlugins[name]; !ok { + if _, ok := builtinplugins.BuiltinPlugins[name]; !ok { return nil, fmt.Errorf("no plugin found with name: %s", name) } From 0da69cf29d23553e877823d8484216aa12e16f78 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 5 Apr 2017 17:19:29 -0700 Subject: [PATCH 058/152] Add postgres builtin plugin --- helper/builtinplugins/builtin.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/helper/builtinplugins/builtin.go b/helper/builtinplugins/builtin.go index 6880640d159e..ceaf10edf975 100644 --- a/helper/builtinplugins/builtin.go +++ b/helper/builtinplugins/builtin.go @@ -1,8 +1,11 @@ package builtinplugins -import "github.com/hashicorp/vault-plugins/database/mysql" +import ( + "github.com/hashicorp/vault-plugins/database/mysql" + "github.com/hashicorp/vault-plugins/database/postgresql" +) var BuiltinPlugins = map[string]func() error{ - "mysql-database-plugin": mysql.Run, - // "postgres-database-plugin": postgres.Run, + "mysql-database-plugin": mysql.Run, + "postgresql-database-plugin": postgresql.Run, } From 8e77bd98d8997e1eb4e55609713ac3889803ed4b Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 6 Apr 2017 12:20:10 -0700 Subject: [PATCH 059/152] Move plugin code into sub directory --- builtin/logical/database/backend.go | 39 +-- builtin/logical/database/dbplugin/client.go | 148 ++++++++ .../{ => dbplugin}/databasemiddleware.go | 2 +- builtin/logical/database/dbplugin/plugin.go | 126 +++++++ .../database/{ => dbplugin}/plugin_test.go | 2 +- builtin/logical/database/dbplugin/server.go | 90 +++++ .../database/path_config_connection.go | 3 +- builtin/logical/database/path_roles.go | 11 +- builtin/logical/database/plugin.go | 324 ------------------ helper/pluginutil/runner.go | 5 + 10 files changed, 385 insertions(+), 365 deletions(-) create mode 100644 builtin/logical/database/dbplugin/client.go rename builtin/logical/database/{ => dbplugin}/databasemiddleware.go (99%) create mode 100644 builtin/logical/database/dbplugin/plugin.go rename builtin/logical/database/{ => dbplugin}/plugin_test.go (99%) create mode 100644 builtin/logical/database/dbplugin/server.go delete mode 100644 builtin/logical/database/plugin.go diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index a2fff4ba863d..baa05a0923f0 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -4,50 +4,23 @@ import ( "fmt" "strings" "sync" - "time" log "github.com/mgutz/logxi/v1" + "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) const databaseConfigPath = "database/dbs/" -// DatabaseType is the interface that all database objects must implement. -type DatabaseType interface { - Type() string - CreateUser(statements Statements, username, password, expiration string) error - RenewUser(statements Statements, username, expiration string) error - RevokeUser(statements Statements, username string) error - - Initialize(map[string]interface{}) error - Close() error - - GenerateUsername(displayName string) (string, error) - GeneratePassword() (string, error) - GenerateExpiration(ttl time.Duration) (string, error) -} - // DatabaseConfig is used by the Factory function to configure a DatabaseType // object. type DatabaseConfig struct { PluginName string `json:"plugin_name" structs:"plugin_name" mapstructure:"plugin_name"` // ConnectionDetails stores the database specific connection settings needed // by each database type. - ConnectionDetails map[string]interface{} `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"` - MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"` - MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"` - MaxConnectionLifetime time.Duration `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"` -} - -// Statements set in role creation and passed into the database type's functions. -// TODO: Add a way of setting defaults here. -type Statements struct { - CreationStatements string `json:"creation_statments" mapstructure:"creation_statements" structs:"creation_statments"` - RevocationStatements string `json:"revocation_statements" mapstructure:"revocation_statements" structs:"revocation_statements"` - RollbackStatements string `json:"rollback_statements" mapstructure:"rollback_statements" structs:"rollback_statements"` - RenewStatements string `json:"renew_statements" mapstructure:"renew_statements" structs:"renew_statements"` + ConnectionDetails map[string]interface{} `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"` } func Factory(conf *logical.BackendConfig) (logical.Backend, error) { @@ -83,12 +56,12 @@ func Backend(conf *logical.BackendConfig) *databaseBackend { } b.logger = conf.Logger - b.connections = make(map[string]DatabaseType) + b.connections = make(map[string]dbplugin.DatabaseType) return &b } type databaseBackend struct { - connections map[string]DatabaseType + connections map[string]dbplugin.DatabaseType logger log.Logger *framework.Backend @@ -108,7 +81,7 @@ func (b *databaseBackend) closeAllDBs() { // This function is used to retrieve a database object either from the cached // connection map or by using the database config in storage. The caller of this // function needs to hold the backend's lock. -func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (DatabaseType, error) { +func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (dbplugin.DatabaseType, error) { // if the object already is built and cached, return it db, ok := b.connections[name] if ok { @@ -128,7 +101,7 @@ func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (Data return nil, err } - db, err = PluginFactory(&config, b.System(), b.logger) + db, err = dbplugin.PluginFactory(config.PluginName, b.System(), b.logger) if err != nil { return nil, err } diff --git a/builtin/logical/database/dbplugin/client.go b/builtin/logical/database/dbplugin/client.go new file mode 100644 index 000000000000..db6b3d1fdaa4 --- /dev/null +++ b/builtin/logical/database/dbplugin/client.go @@ -0,0 +1,148 @@ +package dbplugin + +import ( + "fmt" + "net/rpc" + "sync" + "time" + + "github.com/hashicorp/go-plugin" + "github.com/hashicorp/vault/helper/pluginutil" +) + +// DatabasePluginClient embeds a databasePluginRPCClient and wraps it's close +// method to also call Kill() on the plugin.Client. +type DatabasePluginClient struct { + client *plugin.Client + sync.Mutex + + *databasePluginRPCClient +} + +func (dc *DatabasePluginClient) Close() error { + err := dc.databasePluginRPCClient.Close() + dc.client.Kill() + + return err +} + +// newPluginClient returns a databaseRPCClient with a connection to a running +// plugin. The client is wrapped in a DatabasePluginClient object to ensure the +// plugin is killed on call of Close(). +func newPluginClient(sys pluginutil.Wrapper, pluginRunner *pluginutil.PluginRunner) (DatabaseType, error) { + // pluginMap is the map of plugins we can dispense. + var pluginMap = map[string]plugin.Plugin{ + "database": new(DatabasePlugin), + } + + client, err := pluginRunner.Run(sys, pluginMap, handshakeConfig, []string{}) + if err != nil { + return nil, err + } + + // Connect via RPC + rpcClient, err := client.Client() + if err != nil { + return nil, err + } + + // Request the plugin + raw, err := rpcClient.Dispense("database") + if err != nil { + return nil, err + } + + // We should have a Greeter now! This feels like a normal interface + // implementation but is in fact over an RPC connection. + databaseRPC := raw.(*databasePluginRPCClient) + + return &DatabasePluginClient{ + client: client, + databasePluginRPCClient: databaseRPC, + }, nil +} + +// ---- RPC client domain ---- + +// databasePluginRPCClient impliments DatabaseType and is used on the client to +// make RPC calls to a plugin. +type databasePluginRPCClient struct { + client *rpc.Client +} + +func (dr *databasePluginRPCClient) Type() string { + var dbType string + //TODO: catch error + dr.client.Call("Plugin.Type", struct{}{}, &dbType) + + return fmt.Sprintf("plugin-%s", dbType) +} + +func (dr *databasePluginRPCClient) CreateUser(statements Statements, username, password, expiration string) error { + req := CreateUserRequest{ + Statements: statements, + Username: username, + Password: password, + Expiration: expiration, + } + + err := dr.client.Call("Plugin.CreateUser", req, &struct{}{}) + + return err +} + +func (dr *databasePluginRPCClient) RenewUser(statements Statements, username, expiration string) error { + req := RenewUserRequest{ + Statements: statements, + Username: username, + Expiration: expiration, + } + + err := dr.client.Call("Plugin.RenewUser", req, &struct{}{}) + + return err +} + +func (dr *databasePluginRPCClient) RevokeUser(statements Statements, username string) error { + req := RevokeUserRequest{ + Statements: statements, + Username: username, + } + + err := dr.client.Call("Plugin.RevokeUser", req, &struct{}{}) + + return err +} + +func (dr *databasePluginRPCClient) Initialize(conf map[string]interface{}) error { + err := dr.client.Call("Plugin.Initialize", conf, &struct{}{}) + + return err +} + +func (dr *databasePluginRPCClient) Close() error { + err := dr.client.Call("Plugin.Close", struct{}{}, &struct{}{}) + + return err +} + +func (dr *databasePluginRPCClient) GenerateUsername(displayName string) (string, error) { + resp := &GenerateUsernameResponse{} + err := dr.client.Call("Plugin.GenerateUsername", displayName, resp) + + return resp.Username, err +} + +func (dr *databasePluginRPCClient) GeneratePassword() (string, error) { + resp := &GeneratePasswordResponse{} + err := dr.client.Call("Plugin.GeneratePassword", struct{}{}, resp) + + return resp.Password, err +} + +func (dr *databasePluginRPCClient) GenerateExpiration(duration time.Duration) (string, error) { + resp := &GenerateExpirationResponse{} + err := dr.client.Call("Plugin.GenerateExpiration", duration, resp) + + return resp.Expiration, err +} diff --git a/builtin/logical/database/databasemiddleware.go b/builtin/logical/database/dbplugin/databasemiddleware.go similarity index 99% rename from builtin/logical/database/databasemiddleware.go rename to builtin/logical/database/dbplugin/databasemiddleware.go index 5892e8064a37..b4a9809508b0 100644 --- a/builtin/logical/database/databasemiddleware.go +++ b/builtin/logical/database/dbplugin/databasemiddleware.go @@ -1,4 +1,4 @@ -package database +package dbplugin import ( "time" diff --git a/builtin/logical/database/dbplugin/plugin.go b/builtin/logical/database/dbplugin/plugin.go new file mode 100644 index 000000000000..994f3b0ce95c --- /dev/null +++ b/builtin/logical/database/dbplugin/plugin.go @@ -0,0 +1,126 @@ +package dbplugin + +import ( + "errors" + "net/rpc" + "time" + + "github.com/hashicorp/go-plugin" + "github.com/hashicorp/vault/helper/pluginutil" + log "github.com/mgutz/logxi/v1" +) + +var ( + ErrEmptyPluginName = errors.New("empty plugin name") +) + +// DatabaseType is the interface that all database objects must implement. +type DatabaseType interface { + Type() string + CreateUser(statements Statements, username, password, expiration string) error + RenewUser(statements Statements, username, expiration string) error + RevokeUser(statements Statements, username string) error + + Initialize(map[string]interface{}) error + Close() error + + GenerateUsername(displayName string) (string, error) + GeneratePassword() (string, error) + GenerateExpiration(ttl time.Duration) (string, error) +} + +// Statements set in role creation and passed into the database type's functions. +// TODO: Add a way of setting defaults here. +type Statements struct { + CreationStatements string `json:"creation_statments" mapstructure:"creation_statements" structs:"creation_statments"` + RevocationStatements string `json:"revocation_statements" mapstructure:"revocation_statements" structs:"revocation_statements"` + RollbackStatements string `json:"rollback_statements" mapstructure:"rollback_statements" structs:"rollback_statements"` + RenewStatements string `json:"renew_statements" mapstructure:"renew_statements" structs:"renew_statements"` +} + +// PluginFactory is used to build plugin database types. It wraps the database +// object in a logging and metrics middleware. +func PluginFactory(pluginName string, sys pluginutil.LookWrapper, logger log.Logger) (DatabaseType, error) { + if pluginName == "" { + return nil, ErrEmptyPluginName + } + + pluginMeta, err := sys.LookupPlugin(pluginName) + if err != nil { + return nil, err + } + + db, err := newPluginClient(sys, pluginMeta) + if err != nil { + return nil, err + } + + // Wrap with metrics middleware + db = &databaseMetricsMiddleware{ + next: db, + typeStr: db.Type(), + } + + // Wrap with tracing middleware + db = &databaseTracingMiddleware{ + next: db, + typeStr: db.Type(), + logger: logger, + } + + return db, nil +} + +// handshakeConfigs are used to just do a basic handshake between +// a plugin and host. If the handshake fails, a user friendly error is shown. +// This prevents users from executing bad plugins or executing a plugin +// directory. It is a UX feature, not a security feature. +var handshakeConfig = plugin.HandshakeConfig{ + ProtocolVersion: 1, + MagicCookieKey: "VAULT_DATABASE_PLUGIN", + MagicCookieValue: "926a0820-aea2-be28-51d6-83cdf00e8edb", +} + +type DatabasePlugin struct { + impl DatabaseType +} + +func (d DatabasePlugin) Server(*plugin.MuxBroker) (interface{}, error) { + return &databasePluginRPCServer{impl: d.impl}, nil +} + +func (DatabasePlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, error) { + return &databasePluginRPCClient{client: c}, nil +} + +// ---- RPC Request Args Domain ---- + +type CreateUserRequest struct { + Statements Statements + Username string + Password string + Expiration string +} + +type RenewUserRequest struct { + Statements Statements + Username string + Expiration string +} + +type RevokeUserRequest struct { + Statements Statements + Username string +} + +// ---- RPC Response Args Domain ---- + +type GenerateUsernameResponse struct { + Username string +} +type GenerateExpirationResponse struct { + Expiration string +} +type GeneratePasswordResponse struct { + Password string +} diff --git a/builtin/logical/database/plugin_test.go b/builtin/logical/database/dbplugin/plugin_test.go similarity index 99% rename from builtin/logical/database/plugin_test.go rename to builtin/logical/database/dbplugin/plugin_test.go index 2ec01c9556bb..849e1ebbf463 100644 --- a/builtin/logical/database/plugin_test.go +++ b/builtin/logical/database/dbplugin/plugin_test.go @@ -1,4 +1,4 @@ -package database +package dbplugin import ( "crypto/sha256" diff --git a/builtin/logical/database/dbplugin/server.go b/builtin/logical/database/dbplugin/server.go new file mode 100644 index 000000000000..018d9b8db1f6 --- /dev/null +++ b/builtin/logical/database/dbplugin/server.go @@ -0,0 +1,90 @@ +package dbplugin + +import ( + "time" + + "github.com/hashicorp/go-plugin" + "github.com/hashicorp/vault/helper/pluginutil" +) + +// NewPluginServer is called from within a plugin and wraps the provided +// DatabaseType implimentation in a databasePluginRPCServer object and starts a +// RPC server. +func NewPluginServer(db DatabaseType) { + dbPlugin := &DatabasePlugin{ + impl: db, + } + + // pluginMap is the map of plugins we can dispense. + var pluginMap = map[string]plugin.Plugin{ + "database": dbPlugin, + } + + plugin.Serve(&plugin.ServeConfig{ + HandshakeConfig: handshakeConfig, + Plugins: pluginMap, + TLSProvider: pluginutil.VaultPluginTLSProvider, + }) +} + +// ---- RPC server domain ---- + +// databasePluginRPCServer impliments DatabaseType and is run inside a plugin +type databasePluginRPCServer struct { + impl DatabaseType +} + +func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error { + *resp = ds.impl.Type() + return nil +} + +func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequest, _ *struct{}) error { + err := ds.impl.CreateUser(args.Statements, args.Username, args.Password, args.Expiration) + + return err +} + +func (ds *databasePluginRPCServer) RenewUser(args *RenewUserRequest, _ *struct{}) error { + err := ds.impl.RenewUser(args.Statements, args.Username, args.Expiration) + + return err +} + +func (ds *databasePluginRPCServer) RevokeUser(args *RevokeUserRequest, _ *struct{}) error { + err := ds.impl.RevokeUser(args.Statements, args.Username) + + return err +} + +func (ds *databasePluginRPCServer) Initialize(args map[string]interface{}, _ *struct{}) error { + err := ds.impl.Initialize(args) + + return err +} + +func (ds *databasePluginRPCServer) Close(_ struct{}, _ *struct{}) error { + ds.impl.Close() + return nil +} + +func (ds *databasePluginRPCServer) GenerateUsername(args string, resp *GenerateUsernameResponse) error { + var err error + resp.Username, err = ds.impl.GenerateUsername(args) + + return err +} + +func (ds *databasePluginRPCServer) GeneratePassword(_ struct{}, resp *GeneratePasswordResponse) error { + var err error + resp.Password, err = ds.impl.GeneratePassword() + + return err +} + +func (ds *databasePluginRPCServer) GenerateExpiration(args time.Duration, resp *GenerateExpirationResponse) error { + var err error + resp.Expiration, err = ds.impl.GenerateExpiration(args) + + return err +} diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 48d9b88803f6..4af6e70a0514 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/fatih/structs" + "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -163,7 +164,7 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { b.Lock() defer b.Unlock() - db, err := PluginFactory(config, b.System(), b.logger) + db, err := dbplugin.PluginFactory(config.PluginName, b.System(), b.logger) if err != nil { return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil } diff --git a/builtin/logical/database/path_roles.go b/builtin/logical/database/path_roles.go index d099ef17871e..b3ef6f753db8 100644 --- a/builtin/logical/database/path_roles.go +++ b/builtin/logical/database/path_roles.go @@ -4,6 +4,7 @@ import ( "fmt" "time" + "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -155,7 +156,7 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F "Invalid max_ttl: %s", err)), nil } - statements := Statements{ + statements := dbplugin.Statements{ CreationStatements: creationStmts, RevocationStatements: revocationStmts, RollbackStatements: rollbackStmts, @@ -182,10 +183,10 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F } type roleEntry struct { - DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"` - Statements Statements `json:"statments" mapstructure:"statements" structs:"statments"` - DefaultTTL time.Duration `json:"default_ttl" mapstructure:"default_ttl" structs:"default_ttl"` - MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl" structs:"max_ttl"` + DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"` + Statements dbplugin.Statements `json:"statments" mapstructure:"statements" structs:"statments"` + DefaultTTL time.Duration `json:"default_ttl" mapstructure:"default_ttl" structs:"default_ttl"` + MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl" structs:"max_ttl"` } const pathRoleHelpSyn = ` diff --git a/builtin/logical/database/plugin.go b/builtin/logical/database/plugin.go deleted file mode 100644 index 5a6a8e328598..000000000000 --- a/builtin/logical/database/plugin.go +++ /dev/null @@ -1,324 +0,0 @@ -package database - -import ( - "errors" - "fmt" - "net/rpc" - "sync" - "time" - - "github.com/hashicorp/go-plugin" - "github.com/hashicorp/vault/helper/pluginutil" - "github.com/hashicorp/vault/logical" - log "github.com/mgutz/logxi/v1" -) - -var ( - ErrEmptyPluginName = errors.New("empty plugin name") -) - -// PluginFactory is used to build plugin database types. It wraps the database -// object in a logging and metrics middleware. -func PluginFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Logger) (DatabaseType, error) { - if conf.PluginName == "" { - return nil, ErrEmptyPluginName - } - - pluginMeta, err := sys.LookupPlugin(conf.PluginName) - if err != nil { - return nil, err - } - - db, err := newPluginClient(sys, pluginMeta) - if err != nil { - return nil, err - } - - // Wrap with metrics middleware - db = &databaseMetricsMiddleware{ - next: db, - typeStr: db.Type(), - } - - // Wrap with tracing middleware - db = &databaseTracingMiddleware{ - next: db, - typeStr: db.Type(), - logger: logger, - } - - return db, nil -} - -// handshakeConfigs are used to just do a basic handshake between -// a plugin and host. If the handshake fails, a user friendly error is shown. -// This prevents users from executing bad plugins or executing a plugin -// directory. It is a UX feature, not a security feature. -var handshakeConfig = plugin.HandshakeConfig{ - ProtocolVersion: 1, - MagicCookieKey: "VAULT_DATABASE_PLUGIN", - MagicCookieValue: "926a0820-aea2-be28-51d6-83cdf00e8edb", -} - -type DatabasePlugin struct { - impl DatabaseType -} - -func (d DatabasePlugin) Server(*plugin.MuxBroker) (interface{}, error) { - return &databasePluginRPCServer{impl: d.impl}, nil -} - -func (DatabasePlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, error) { - return &databasePluginRPCClient{client: c}, nil -} - -// DatabasePluginClient embeds a databasePluginRPCClient and wraps it's close -// method to also call Kill() on the plugin.Client. -type DatabasePluginClient struct { - client *plugin.Client - sync.Mutex - - *databasePluginRPCClient -} - -func (dc *DatabasePluginClient) Close() error { - err := dc.databasePluginRPCClient.Close() - dc.client.Kill() - - return err -} - -// newPluginClient returns a databaseRPCClient with a connection to a running -// plugin. The client is wrapped in a DatabasePluginClient object to ensure the -// plugin is killed on call of Close(). -func newPluginClient(sys pluginutil.Wrapper, pluginRunner *pluginutil.PluginRunner) (DatabaseType, error) { - // pluginMap is the map of plugins we can dispense. - var pluginMap = map[string]plugin.Plugin{ - "database": new(DatabasePlugin), - } - - client, err := pluginRunner.Run(sys, pluginMap, handshakeConfig, []string{}) - if err != nil { - return nil, err - } - - // Connect via RPC - rpcClient, err := client.Client() - if err != nil { - return nil, err - } - - // Request the plugin - raw, err := rpcClient.Dispense("database") - if err != nil { - return nil, err - } - - // We should have a Greeter now! This feels like a normal interface - // implementation but is in fact over an RPC connection. - databaseRPC := raw.(*databasePluginRPCClient) - - return &DatabasePluginClient{ - client: client, - databasePluginRPCClient: databaseRPC, - }, nil -} - -// NewPluginServer is called from within a plugin and wraps the provided -// DatabaseType implimentation in a databasePluginRPCServer object and starts a -// RPC server. -func NewPluginServer(db DatabaseType) { - dbPlugin := &DatabasePlugin{ - impl: db, - } - - // pluginMap is the map of plugins we can dispense. - var pluginMap = map[string]plugin.Plugin{ - "database": dbPlugin, - } - - plugin.Serve(&plugin.ServeConfig{ - HandshakeConfig: handshakeConfig, - Plugins: pluginMap, - TLSProvider: pluginutil.VaultPluginTLSProvider, - }) -} - -// ---- RPC client domain ---- - -// databasePluginRPCClient impliments DatabaseType and is used on the client to -// make RPC calls to a plugin. -type databasePluginRPCClient struct { - client *rpc.Client -} - -func (dr *databasePluginRPCClient) Type() string { - var dbType string - //TODO: catch error - dr.client.Call("Plugin.Type", struct{}{}, &dbType) - - return fmt.Sprintf("plugin-%s", dbType) -} - -func (dr *databasePluginRPCClient) CreateUser(statements Statements, username, password, expiration string) error { - req := CreateUserRequest{ - Statements: statements, - Username: username, - Password: password, - Expiration: expiration, - } - - err := dr.client.Call("Plugin.CreateUser", req, &struct{}{}) - - return err -} - -func (dr *databasePluginRPCClient) RenewUser(statements Statements, username, expiration string) error { - req := RenewUserRequest{ - Statements: statements, - Username: username, - Expiration: expiration, - } - - err := dr.client.Call("Plugin.RenewUser", req, &struct{}{}) - - return err -} - -func (dr *databasePluginRPCClient) RevokeUser(statements Statements, username string) error { - req := RevokeUserRequest{ - Statements: statements, - Username: username, - } - - err := dr.client.Call("Plugin.RevokeUser", req, &struct{}{}) - - return err -} - -func (dr *databasePluginRPCClient) Initialize(conf map[string]interface{}) error { - err := dr.client.Call("Plugin.Initialize", conf, &struct{}{}) - - return err -} - -func (dr *databasePluginRPCClient) Close() error { - err := dr.client.Call("Plugin.Close", struct{}{}, &struct{}{}) - - return err -} - -func (dr *databasePluginRPCClient) GenerateUsername(displayName string) (string, error) { - resp := &GenerateUsernameResponse{} - err := dr.client.Call("Plugin.GenerateUsername", displayName, resp) - - return resp.Username, err -} - -func (dr *databasePluginRPCClient) GeneratePassword() (string, error) { - resp := &GeneratePasswordResponse{} - err := dr.client.Call("Plugin.GeneratePassword", struct{}{}, resp) - - return resp.Password, err -} - -func (dr *databasePluginRPCClient) GenerateExpiration(duration time.Duration) (string, error) { - resp := &GenerateExpirationResponse{} - err := dr.client.Call("Plugin.GenerateExpiration", duration, resp) - - return resp.Expiration, err -} - -// ---- RPC server domain ---- - -// databasePluginRPCServer impliments DatabaseType and is run inside a plugin -type databasePluginRPCServer struct { - impl DatabaseType -} - -func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error { - *resp = ds.impl.Type() - return nil -} - -func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequest, _ *struct{}) error { - err := ds.impl.CreateUser(args.Statements, args.Username, args.Password, args.Expiration) - - return err -} - -func (ds *databasePluginRPCServer) RenewUser(args *RenewUserRequest, _ *struct{}) error { - err := ds.impl.RenewUser(args.Statements, args.Username, args.Expiration) - - return err -} - -func (ds *databasePluginRPCServer) RevokeUser(args *RevokeUserRequest, _ *struct{}) error { - err := ds.impl.RevokeUser(args.Statements, args.Username) - - return err -} - -func (ds *databasePluginRPCServer) Initialize(args map[string]interface{}, _ *struct{}) error { - err := ds.impl.Initialize(args) - - return err -} - -func (ds *databasePluginRPCServer) Close(_ struct{}, _ *struct{}) error { - ds.impl.Close() - return nil -} - -func (ds *databasePluginRPCServer) GenerateUsername(args string, resp *GenerateUsernameResponse) error { - var err error - resp.Username, err = ds.impl.GenerateUsername(args) - - return err -} - -func (ds *databasePluginRPCServer) GeneratePassword(_ struct{}, resp *GeneratePasswordResponse) error { - var err error - resp.Password, err = ds.impl.GeneratePassword() - - return err -} - -func (ds *databasePluginRPCServer) GenerateExpiration(args time.Duration, resp *GenerateExpirationResponse) error { - var err error - resp.Expiration, err = ds.impl.GenerateExpiration(args) - - return err -} - -// ---- Request Args Domain ---- - -type CreateUserRequest struct { - Statements Statements - Username string - Password string - Expiration string -} - -type RenewUserRequest struct { - Statements Statements - Username string - Expiration string -} - -type RevokeUserRequest struct { - Statements Statements - Username string -} - -// ---- Response Args Domain ---- - -type GenerateUsernameResponse struct { - Username string -} -type GenerateExpirationResponse struct { - Expiration string -} -type GeneratePasswordResponse struct { - Password string -} diff --git a/helper/pluginutil/runner.go b/helper/pluginutil/runner.go index 143a4c839145..90569dd9ad1f 100644 --- a/helper/pluginutil/runner.go +++ b/helper/pluginutil/runner.go @@ -12,6 +12,11 @@ type Looker interface { LookupPlugin(string) (*PluginRunner, error) } +type LookWrapper interface { + Looker + Wrapper +} + type PluginRunner struct { Name string `json:"name"` Command string `json:"command"` From 9ae5a2aedee98b68fffddabcbc0fc9a0d1d68f34 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Fri, 7 Apr 2017 15:50:03 -0700 Subject: [PATCH 060/152] Add backend test --- builtin/logical/database/backend_test.go | 567 ++++++++++++++++++ .../database/path_config_connection.go | 2 + builtin/logical/database/path_roles.go | 2 +- builtin/logical/database/secret_creds.go | 2 +- command/plugin-exec.go | 2 +- helper/builtinplugins/builtin.go | 19 +- vault/plugin_catalog.go | 2 +- vault/testing.go | 45 +- 8 files changed, 633 insertions(+), 8 deletions(-) create mode 100644 builtin/logical/database/backend_test.go diff --git a/builtin/logical/database/backend_test.go b/builtin/logical/database/backend_test.go new file mode 100644 index 000000000000..5cb84476d4d4 --- /dev/null +++ b/builtin/logical/database/backend_test.go @@ -0,0 +1,567 @@ +package database + +import ( + "database/sql" + "errors" + "fmt" + "log" + "net" + "os" + "reflect" + "strings" + "sync" + "testing" + + "github.com/hashicorp/vault/builtin/logical/database/dbplugin" + "github.com/hashicorp/vault/helper/builtinplugins" + "github.com/hashicorp/vault/helper/pluginutil" + "github.com/hashicorp/vault/http" + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/vault" + "github.com/lib/pq" + "github.com/mitchellh/mapstructure" + dockertest "gopkg.in/ory-am/dockertest.v3" +) + +var ( + testImagePull sync.Once +) + +func preparePostgresTestContainer(t *testing.T, s logical.Storage, b logical.Backend) (cleanup func(), retURL string) { + if os.Getenv("PG_URL") != "" { + return func() {}, os.Getenv("PG_URL") + } + + pool, err := dockertest.NewPool("") + if err != nil { + t.Fatalf("Failed to connect to docker: %s", err) + } + + resource, err := pool.Run("postgres", "latest", []string{"POSTGRES_PASSWORD=secret", "POSTGRES_DB=database"}) + if err != nil { + t.Fatalf("Could not start local PostgreSQL docker container: %s", err) + } + + cleanup = func() { + err := pool.Purge(resource) + if err != nil { + t.Fatalf("Failed to cleanup local container: %s", err) + } + } + + retURL = fmt.Sprintf("postgres://postgres:secret@localhost:%s/database?sslmode=disable", resource.GetPort("5432/tcp")) + + // exponential backoff-retry + if err = pool.Retry(func() error { + // This will cause a validation to run + resp, err := b.HandleRequest(&logical.Request{ + Storage: s, + Operation: logical.UpdateOperation, + Path: "config/postgresql", + Data: map[string]interface{}{ + "plugin_name": "postgresql-database-plugin", + "connection_url": retURL, + }, + }) + if err != nil || (resp != nil && resp.IsError()) { + // It's likely not up and running yet, so return error and try again + return fmt.Errorf("err:%s resp:%#v\n", err, resp) + } + if resp == nil { + t.Fatal("expected warning") + } + + return nil + }); err != nil { + t.Fatalf("Could not connect to PostgreSQL docker container: %s", err) + } + + return +} + +func getCore(t *testing.T) (*vault.Core, net.Listener, logical.SystemView, string) { + core, _, token, ln := vault.TestCoreUnsealedWithListener(t) + http.TestServerWithListener(t, ln, "", core) + sys := vault.TestDynamicSystemView(core) + vault.TestAddTestPlugin(t, core, "postgresql-database-plugin", fmt.Sprintf("%s -test.run=TestBackend_PluginMain", os.Args[0])) + + return core, ln, sys, token +} + +func TestBackend_PluginMain(t *testing.T) { + if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" { + return + } + + f, _ := builtinplugins.BuiltinPlugins.Get("postgresql-database-plugin") + f() +} + +func TestBackend_config_connection(t *testing.T) { + var resp *logical.Response + var err error + _, ln, sys, _ := getCore(t) + defer ln.Close() + + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + config.System = sys + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + defer b.Cleanup() + + configData := map[string]interface{}{ + "connection_url": "sample_connection_url", + "plugin_name": "postgresql-database-plugin", + "verify_connection": false, + } + + configReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "config/plugin-test", + Storage: config.StorageView, + Data: configData, + } + resp, err = b.HandleRequest(configReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + expected := map[string]interface{}{ + "plugin_name": "postgresql-database-plugin", + "connection_details": configData, + } + configReq.Operation = logical.ReadOperation + resp, err = b.HandleRequest(configReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + delete(resp.Data["connection_details"].(map[string]interface{}), "name") + if !reflect.DeepEqual(expected, resp.Data) { + t.Fatalf("bad: expected:%#v\nactual:%#v\n", expected, resp.Data) + } +} + +func TestBackend_basic(t *testing.T) { + _, ln, sys, _ := getCore(t) + defer ln.Close() + + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + config.System = sys + + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + defer b.Cleanup() + + cleanup, connURL := preparePostgresTestContainer(t, config.StorageView, b) + defer cleanup() + + // Configure a connection + data := map[string]interface{}{ + "connection_url": connURL, + "plugin_name": "postgresql-database-plugin", + } + req := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "config/plugin-test", + Storage: config.StorageView, + Data: data, + } + resp, err := b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Create a role + data = map[string]interface{}{ + "db_name": "plugin-test", + "creation_statements": testRole, + "default_ttl": "5m", + "max_ttl": "10m", + } + req = &logical.Request{ + Operation: logical.UpdateOperation, + Path: "roles/plugin-role-test", + Storage: config.StorageView, + Data: data, + } + resp, err = b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Get creds + data = map[string]interface{}{} + req = &logical.Request{ + Operation: logical.ReadOperation, + Path: "creds/plugin-role-test", + Storage: config.StorageView, + Data: data, + } + credsResp, err := b.HandleRequest(req) + if err != nil || (credsResp != nil && credsResp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, credsResp) + } + + if testCredsByCount(t, credsResp, connURL) != 2 { + t.Fatalf("Got wrong number of creds") + } + + // Revoke creds + resp, err = b.HandleRequest(&logical.Request{ + Operation: logical.RevokeOperation, + Storage: config.StorageView, + Secret: &logical.Secret{ + InternalData: map[string]interface{}{ + "secret_type": "creds", + "username": credsResp.Data["username"], + "role": "plugin-role-test", + }, + }, + }) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + if testCredsByCount(t, credsResp, connURL) != -1 { + t.Fatalf("Got wrong number of creds") + } + +} + +func TestBackend_roleCrud(t *testing.T) { + _, ln, sys, _ := getCore(t) + defer ln.Close() + + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + config.System = sys + + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + defer b.Cleanup() + + cleanup, connURL := preparePostgresTestContainer(t, config.StorageView, b) + defer cleanup() + + // Configure a connection + data := map[string]interface{}{ + "connection_url": connURL, + "plugin_name": "postgresql-database-plugin", + } + req := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "config/plugin-test", + Storage: config.StorageView, + Data: data, + } + resp, err := b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Create a role + data = map[string]interface{}{ + "db_name": "plugin-test", + "creation_statements": testRole, + "revocation_statements": defaultRevocationSQL, + "default_ttl": "5m", + "max_ttl": "10m", + } + req = &logical.Request{ + Operation: logical.UpdateOperation, + Path: "roles/plugin-role-test", + Storage: config.StorageView, + Data: data, + } + resp, err = b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Read the role + data = map[string]interface{}{} + req = &logical.Request{ + Operation: logical.ReadOperation, + Path: "roles/plugin-role-test", + Storage: config.StorageView, + Data: data, + } + resp, err = b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + expected := dbplugin.Statements{ + CreationStatements: testRole, + RevocationStatements: defaultRevocationSQL, + } + + var actual dbplugin.Statements + if err := mapstructure.Decode(resp.Data, &actual); err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(expected, actual) { + t.Fatalf("Statements did not match, exepected %#v, got %#v", expected, actual) + } + + // Delete the role + data = map[string]interface{}{} + req = &logical.Request{ + Operation: logical.DeleteOperation, + Path: "roles/plugin-role-test", + Storage: config.StorageView, + Data: data, + } + resp, err = b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Read the role + data = map[string]interface{}{} + req = &logical.Request{ + Operation: logical.ReadOperation, + Path: "roles/plugin-role-test", + Storage: config.StorageView, + Data: data, + } + resp, err = b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Should be empty + if resp != nil { + t.Fatal("Expected response to be nil") + } +} + +func TestBackend_roleReadOnly(t *testing.T) { + _, ln, sys, _ := getCore(t) + defer ln.Close() + + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + config.System = sys + + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + defer b.Cleanup() + + cleanup, connURL := preparePostgresTestContainer(t, config.StorageView, b) + defer cleanup() + + // Configure a connection + data := map[string]interface{}{ + "connection_url": connURL, + "plugin_name": "postgresql-database-plugin", + } + req := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "config/plugin-test", + Storage: config.StorageView, + Data: data, + } + resp, err := b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Create a role + data = map[string]interface{}{ + "db_name": "plugin-test", + "creation_statements": testRole, + "default_ttl": "5m", + "max_ttl": "10m", + } + req = &logical.Request{ + Operation: logical.UpdateOperation, + Path: "roles/plugin-role-test", + Storage: config.StorageView, + Data: data, + } + resp, err = b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Create a readonly role + data = map[string]interface{}{ + "db_name": "plugin-test", + "creation_statements": testReadOnlyRole, + "default_ttl": "5m", + "max_ttl": "10m", + } + req = &logical.Request{ + Operation: logical.UpdateOperation, + Path: "roles/plugin-readonly-role-test", + Storage: config.StorageView, + Data: data, + } + resp, err = b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Get creds + data = map[string]interface{}{} + req = &logical.Request{ + Operation: logical.ReadOperation, + Path: "creds/plugin-role-test", + Storage: config.StorageView, + Data: data, + } + credsResp, err := b.HandleRequest(req) + if err != nil || (credsResp != nil && credsResp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, credsResp) + } + + if i := testCredsByCount(t, credsResp, connURL); i != 2 { + t.Fatalf("Got wrong number of creds got %d, expected 2", i) + } + + // Get readonly creds + data = map[string]interface{}{} + req = &logical.Request{ + Operation: logical.ReadOperation, + Path: "creds/plugin-readonly-role-test", + Storage: config.StorageView, + Data: data, + } + readOnlyCredsResp, err := b.HandleRequest(req) + if err != nil || (credsResp != nil && credsResp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, readOnlyCredsResp) + } + + if i := testCredsByCount(t, readOnlyCredsResp, connURL); i != 2 { + t.Fatalf("Got wrong number of creds got %d, expected 2", i) + } + + if err := testCreateTable(t, readOnlyCredsResp, connURL); err == nil { + t.Fatal("Read only creds should return error on table creation") + } + + if err := testCreateTable(t, credsResp, connURL); err != nil { + t.Fatalf("Error on table creation: %s", err) + } +} + +func testCredsByCount(t *testing.T, resp *logical.Response, connURL string) int { + var d struct { + Username string `mapstructure:"username"` + Password string `mapstructure:"password"` + } + if err := mapstructure.Decode(resp.Data, &d); err != nil { + t.Fatal(err) + } + log.Printf("[TRACE] Generated credentials: %v", d) + conn, err := pq.ParseURL(connURL) + + if err != nil { + t.Fatal(err) + } + + conn += " timezone=utc" + + db, err := sql.Open("postgres", conn) + if err != nil { + t.Fatal(err) + } + + returnedRows := func() int { + stmt, err := db.Prepare("SELECT DISTINCT schemaname FROM pg_tables WHERE has_table_privilege($1, 'information_schema.role_column_grants', 'select');") + if err != nil { + return -1 + } + defer stmt.Close() + + rows, err := stmt.Query(d.Username) + if err != nil { + return -1 + } + defer rows.Close() + + i := 0 + for rows.Next() { + i++ + } + return i + } + + return returnedRows() +} + +func testCreateTable(t *testing.T, resp *logical.Response, connURL string) error { + var d struct { + Username string `mapstructure:"username"` + Password string `mapstructure:"password"` + } + if err := mapstructure.Decode(resp.Data, &d); err != nil { + t.Fatal(err) + } + + connURL = strings.Replace(connURL, "postgres:secret", fmt.Sprintf("%s:%s", d.Username, d.Password), 1) + + fmt.Println(connURL) + log.Printf("[TRACE] Generated credentials: %v", d) + conn, err := pq.ParseURL(connURL) + if err != nil { + t.Fatal(err) + } + + conn += " timezone=utc" + + db, err := sql.Open("postgres", conn) + if err != nil { + t.Fatal(err) + } + + r, err := db.Exec("CREATE TABLE test1 (id SERIAL PRIMARY KEY);") + if err != nil { + return err + } + + if i, _ := r.RowsAffected(); i != 1 { + return errors.New("Did not create db") + } + + return nil +} + +const testRole = ` +CREATE ROLE "{{name}}" WITH + LOGIN + PASSWORD '{{password}}' + VALID UNTIL '{{expiration}}'; +GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}"; +` + +const testReadOnlyRole = ` +CREATE ROLE "{{name}}" WITH + LOGIN + PASSWORD '{{password}}' + VALID UNTIL '{{expiration}}'; +REVOKE ALL ON SCHEMA public FROM "{{name}}"; +GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}"; +GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}"; +` + +const defaultRevocationSQL = ` +REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM {{name}}; +REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM {{name}}; +REVOKE USAGE ON SCHEMA public FROM {{name}}; + +DROP ROLE IF EXISTS {{name}}; +` diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 4af6e70a0514..1b8a6583152d 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -172,10 +172,12 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { err = db.Initialize(config.ConnectionDetails) if err != nil { if !strings.Contains(err.Error(), "Error Initializing Connection") { + db.Close() return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil } if verifyConnection { + db.Close() return logical.ErrorResponse("Could not verify connection"), nil } } diff --git a/builtin/logical/database/path_roles.go b/builtin/logical/database/path_roles.go index b3ef6f753db8..a6989df2481c 100644 --- a/builtin/logical/database/path_roles.go +++ b/builtin/logical/database/path_roles.go @@ -105,7 +105,7 @@ func (b *databaseBackend) pathRoleRead(req *logical.Request, data *framework.Fie return &logical.Response{ Data: map[string]interface{}{ - "creation_statments": role.Statements.CreationStatements, + "creation_statements": role.Statements.CreationStatements, "revocation_statements": role.Statements.RevocationStatements, "rollback_statements": role.Statements.RollbackStatements, "renew_statements": role.Statements.RenewStatements, diff --git a/builtin/logical/database/secret_creds.go b/builtin/logical/database/secret_creds.go index 2b63ea1f89b0..353541c0cc35 100644 --- a/builtin/logical/database/secret_creds.go +++ b/builtin/logical/database/secret_creds.go @@ -81,7 +81,7 @@ func (b *databaseBackend) secretCredsRevoke(req *logical.Request, d *framework.F roleNameRaw, ok := req.Secret.InternalData["role"] if !ok { - return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"]) + return nil, fmt.Errorf("no role name was provided") } role, err := b.Role(req.Storage, roleNameRaw.(string)) diff --git a/command/plugin-exec.go b/command/plugin-exec.go index 70bc8ae1d4d3..575be14b7d91 100644 --- a/command/plugin-exec.go +++ b/command/plugin-exec.go @@ -29,7 +29,7 @@ func (c *PluginExec) Run(args []string) int { pluginName := args[0] - runner, ok := builtinplugins.BuiltinPlugins[pluginName] + runner, ok := builtinplugins.BuiltinPlugins.Get(pluginName) if !ok { c.Ui.Error(fmt.Sprintf( "No plugin with the name %s found", pluginName)) diff --git a/helper/builtinplugins/builtin.go b/helper/builtinplugins/builtin.go index ceaf10edf975..ba3769c900a1 100644 --- a/helper/builtinplugins/builtin.go +++ b/helper/builtinplugins/builtin.go @@ -5,7 +5,20 @@ import ( "github.com/hashicorp/vault-plugins/database/postgresql" ) -var BuiltinPlugins = map[string]func() error{ - "mysql-database-plugin": mysql.Run, - "postgresql-database-plugin": postgresql.Run, +var BuiltinPlugins *builtinPlugins = &builtinPlugins{ + plugins: map[string]func() error{ + "mysql-database-plugin": mysql.Run, + "postgresql-database-plugin": postgresql.Run, + }, +} + +// The list of builtin plugins should not be changed by any other package, so we +// store them in an unexported variable in this unexported struct. +type builtinPlugins struct { + plugins map[string]func() error +} + +func (b *builtinPlugins) Get(name string) (func() error, bool) { + f, ok := b.plugins[name] + return f, ok } diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go index b9c15db22a52..a42f85ec115a 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugin_catalog.go @@ -54,7 +54,7 @@ func (c *PluginCatalog) Get(name string) (*pluginutil.PluginRunner, error) { } // Look for builtin plugins - if _, ok := builtinplugins.BuiltinPlugins[name]; !ok { + if _, ok := builtinplugins.BuiltinPlugins.Get(name); !ok { return nil, fmt.Errorf("no plugin found with name: %s", name) } diff --git a/vault/testing.go b/vault/testing.go index 7b914bbdbbee..fdf55b4e595a 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -8,9 +8,13 @@ import ( "crypto/x509" "encoding/pem" "fmt" + "io" "net" "net/http" + "os" "os/exec" + "path/filepath" + "strings" "testing" "time" @@ -306,7 +310,46 @@ func TestKeyCopy(key []byte) []byte { } func TestDynamicSystemView(c *Core) *dynamicSystemView { - return &dynamicSystemView{c, nil} + me := &MountEntry{ + Config: MountConfig{ + DefaultLeaseTTL: 24 * time.Hour, + MaxLeaseTTL: 2 * 24 * time.Hour, + }, + } + + return &dynamicSystemView{c, me} +} + +func TestAddTestPlugin(t testing.TB, c *Core, name, command string) { + parts := strings.Split(command, " ") + + file, err := os.Open(parts[0]) + if err != nil { + t.Fatal(err) + } + defer file.Close() + + hash := sha256.New() + + _, err = io.Copy(hash, file) + if err != nil { + t.Fatal(err) + } + + sum := hash.Sum(nil) + c.pluginCatalog.directory, err = filepath.EvalSymlinks(parts[0]) + if err != nil { + t.Fatal(err) + } + c.pluginCatalog.directory = filepath.Dir(c.pluginCatalog.directory) + + parts[0] = filepath.Base(parts[0]) + command = strings.Join(parts, " ") + + err = c.pluginCatalog.Set(name, command, sum) + if err != nil { + t.Fatal(err) + } } var testLogicalBackends = map[string]logical.Factory{} From 3c1c388589ecfb8309dce8385db565ddfb3c2e13 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 10 Apr 2017 10:35:16 -0700 Subject: [PATCH 061/152] Update backend tests --- builtin/logical/database/backend_test.go | 210 ++++++++---------- .../database/path_config_connection.go | 2 +- 2 files changed, 96 insertions(+), 116 deletions(-) diff --git a/builtin/logical/database/backend_test.go b/builtin/logical/database/backend_test.go index 5cb84476d4d4..fc41cf3cd83c 100644 --- a/builtin/logical/database/backend_test.go +++ b/builtin/logical/database/backend_test.go @@ -2,13 +2,11 @@ package database import ( "database/sql" - "errors" "fmt" "log" "net" "os" "reflect" - "strings" "sync" "testing" @@ -209,8 +207,8 @@ func TestBackend_basic(t *testing.T) { t.Fatalf("err:%s resp:%#v\n", err, credsResp) } - if testCredsByCount(t, credsResp, connURL) != 2 { - t.Fatalf("Got wrong number of creds") + if !testCredsExist(t, credsResp, connURL) { + t.Fatalf("Creds should exist") } // Revoke creds @@ -229,13 +227,13 @@ func TestBackend_basic(t *testing.T) { t.Fatalf("err:%s resp:%#v\n", err, resp) } - if testCredsByCount(t, credsResp, connURL) != -1 { - t.Fatalf("Got wrong number of creds") + if testCredsExist(t, credsResp, connURL) { + t.Fatalf("Creds should not exist") } } -func TestBackend_roleCrud(t *testing.T) { +func TestBackend_connectionCrud(t *testing.T) { _, ln, sys, _ := getCore(t) defer ln.Close() @@ -254,8 +252,9 @@ func TestBackend_roleCrud(t *testing.T) { // Configure a connection data := map[string]interface{}{ - "connection_url": connURL, - "plugin_name": "postgresql-database-plugin", + "connection_url": "test", + "plugin_name": "postgresql-database-plugin", + "verify_connection": false, } req := &logical.Request{ Operation: logical.UpdateOperation, @@ -287,11 +286,14 @@ func TestBackend_roleCrud(t *testing.T) { t.Fatalf("err:%s resp:%#v\n", err, resp) } - // Read the role - data = map[string]interface{}{} + // Update the connection + data = map[string]interface{}{ + "connection_url": connURL, + "plugin_name": "postgresql-database-plugin", + } req = &logical.Request{ - Operation: logical.ReadOperation, - Path: "roles/plugin-role-test", + Operation: logical.UpdateOperation, + Path: "config/plugin-test", Storage: config.StorageView, Data: data, } @@ -300,25 +302,27 @@ func TestBackend_roleCrud(t *testing.T) { t.Fatalf("err:%s resp:%#v\n", err, resp) } - expected := dbplugin.Statements{ - CreationStatements: testRole, - RevocationStatements: defaultRevocationSQL, + // Read connection + expected := map[string]interface{}{ + "plugin_name": "postgresql-database-plugin", + "connection_details": data, } - - var actual dbplugin.Statements - if err := mapstructure.Decode(resp.Data, &actual); err != nil { - t.Fatal(err) + req.Operation = logical.ReadOperation + resp, err = b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) } - if !reflect.DeepEqual(expected, actual) { - t.Fatalf("Statements did not match, exepected %#v, got %#v", expected, actual) + delete(resp.Data["connection_details"].(map[string]interface{}), "name") + if !reflect.DeepEqual(expected, resp.Data) { + t.Fatalf("bad: expected:%#v\nactual:%#v\n", expected, resp.Data) } - // Delete the role + // Reset Connection data = map[string]interface{}{} req = &logical.Request{ - Operation: logical.DeleteOperation, - Path: "roles/plugin-role-test", + Operation: logical.UpdateOperation, + Path: "reset/plugin-test", Storage: config.StorageView, Data: data, } @@ -327,11 +331,28 @@ func TestBackend_roleCrud(t *testing.T) { t.Fatalf("err:%s resp:%#v\n", err, resp) } - // Read the role + // Get creds data = map[string]interface{}{} req = &logical.Request{ Operation: logical.ReadOperation, - Path: "roles/plugin-role-test", + Path: "creds/plugin-role-test", + Storage: config.StorageView, + Data: data, + } + credsResp, err := b.HandleRequest(req) + if err != nil || (credsResp != nil && credsResp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, credsResp) + } + + if !testCredsExist(t, credsResp, connURL) { + t.Fatalf("Creds should exist") + } + + // Delete Connection + data = map[string]interface{}{} + req = &logical.Request{ + Operation: logical.DeleteOperation, + Path: "config/plugin-test", Storage: config.StorageView, Data: data, } @@ -340,13 +361,20 @@ func TestBackend_roleCrud(t *testing.T) { t.Fatalf("err:%s resp:%#v\n", err, resp) } + // Read connection + req.Operation = logical.ReadOperation + resp, err = b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + // Should be empty if resp != nil { t.Fatal("Expected response to be nil") } } -func TestBackend_roleReadOnly(t *testing.T) { +func TestBackend_roleCrud(t *testing.T) { _, ln, sys, _ := getCore(t) defer ln.Close() @@ -381,10 +409,11 @@ func TestBackend_roleReadOnly(t *testing.T) { // Create a role data = map[string]interface{}{ - "db_name": "plugin-test", - "creation_statements": testRole, - "default_ttl": "5m", - "max_ttl": "10m", + "db_name": "plugin-test", + "creation_statements": testRole, + "revocation_statements": defaultRevocationSQL, + "default_ttl": "5m", + "max_ttl": "10m", } req = &logical.Request{ Operation: logical.UpdateOperation, @@ -397,16 +426,11 @@ func TestBackend_roleReadOnly(t *testing.T) { t.Fatalf("err:%s resp:%#v\n", err, resp) } - // Create a readonly role - data = map[string]interface{}{ - "db_name": "plugin-test", - "creation_statements": testReadOnlyRole, - "default_ttl": "5m", - "max_ttl": "10m", - } + // Read the role + data = map[string]interface{}{} req = &logical.Request{ - Operation: logical.UpdateOperation, - Path: "roles/plugin-readonly-role-test", + Operation: logical.ReadOperation, + Path: "roles/plugin-role-test", Storage: config.StorageView, Data: data, } @@ -415,50 +439,53 @@ func TestBackend_roleReadOnly(t *testing.T) { t.Fatalf("err:%s resp:%#v\n", err, resp) } - // Get creds - data = map[string]interface{}{} - req = &logical.Request{ - Operation: logical.ReadOperation, - Path: "creds/plugin-role-test", - Storage: config.StorageView, - Data: data, + expected := dbplugin.Statements{ + CreationStatements: testRole, + RevocationStatements: defaultRevocationSQL, } - credsResp, err := b.HandleRequest(req) - if err != nil || (credsResp != nil && credsResp.IsError()) { - t.Fatalf("err:%s resp:%#v\n", err, credsResp) + + var actual dbplugin.Statements + if err := mapstructure.Decode(resp.Data, &actual); err != nil { + t.Fatal(err) } - if i := testCredsByCount(t, credsResp, connURL); i != 2 { - t.Fatalf("Got wrong number of creds got %d, expected 2", i) + if !reflect.DeepEqual(expected, actual) { + t.Fatalf("Statements did not match, exepected %#v, got %#v", expected, actual) } - // Get readonly creds + // Delete the role data = map[string]interface{}{} req = &logical.Request{ - Operation: logical.ReadOperation, - Path: "creds/plugin-readonly-role-test", + Operation: logical.DeleteOperation, + Path: "roles/plugin-role-test", Storage: config.StorageView, Data: data, } - readOnlyCredsResp, err := b.HandleRequest(req) - if err != nil || (credsResp != nil && credsResp.IsError()) { - t.Fatalf("err:%s resp:%#v\n", err, readOnlyCredsResp) + resp, err = b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) } - if i := testCredsByCount(t, readOnlyCredsResp, connURL); i != 2 { - t.Fatalf("Got wrong number of creds got %d, expected 2", i) + // Read the role + data = map[string]interface{}{} + req = &logical.Request{ + Operation: logical.ReadOperation, + Path: "roles/plugin-role-test", + Storage: config.StorageView, + Data: data, } - - if err := testCreateTable(t, readOnlyCredsResp, connURL); err == nil { - t.Fatal("Read only creds should return error on table creation") + resp, err = b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) } - if err := testCreateTable(t, credsResp, connURL); err != nil { - t.Fatalf("Error on table creation: %s", err) + // Should be empty + if resp != nil { + t.Fatal("Expected response to be nil") } } -func testCredsByCount(t *testing.T, resp *logical.Response, connURL string) int { +func testCredsExist(t *testing.T, resp *logical.Response, connURL string) bool { var d struct { Username string `mapstructure:"username"` Password string `mapstructure:"password"` @@ -500,44 +527,7 @@ func testCredsByCount(t *testing.T, resp *logical.Response, connURL string) int return i } - return returnedRows() -} - -func testCreateTable(t *testing.T, resp *logical.Response, connURL string) error { - var d struct { - Username string `mapstructure:"username"` - Password string `mapstructure:"password"` - } - if err := mapstructure.Decode(resp.Data, &d); err != nil { - t.Fatal(err) - } - - connURL = strings.Replace(connURL, "postgres:secret", fmt.Sprintf("%s:%s", d.Username, d.Password), 1) - - fmt.Println(connURL) - log.Printf("[TRACE] Generated credentials: %v", d) - conn, err := pq.ParseURL(connURL) - if err != nil { - t.Fatal(err) - } - - conn += " timezone=utc" - - db, err := sql.Open("postgres", conn) - if err != nil { - t.Fatal(err) - } - - r, err := db.Exec("CREATE TABLE test1 (id SERIAL PRIMARY KEY);") - if err != nil { - return err - } - - if i, _ := r.RowsAffected(); i != 1 { - return errors.New("Did not create db") - } - - return nil + return returnedRows() == 2 } const testRole = ` @@ -548,16 +538,6 @@ CREATE ROLE "{{name}}" WITH GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}"; ` -const testReadOnlyRole = ` -CREATE ROLE "{{name}}" WITH - LOGIN - PASSWORD '{{password}}' - VALID UNTIL '{{expiration}}'; -REVOKE ALL ON SCHEMA public FROM "{{name}}"; -GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}"; -GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}"; -` - const defaultRevocationSQL = ` REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM {{name}}; REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM {{name}}; diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 1b8a6583152d..7589669a4625 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -171,7 +171,7 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { err = db.Initialize(config.ConnectionDetails) if err != nil { - if !strings.Contains(err.Error(), "Error Initializing Connection") { + if !strings.Contains(err.Error(), "error initalizing connection") { db.Close() return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil } From 73f66f89cda251de7564194b19932e2d49e25302 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 10 Apr 2017 12:24:16 -0700 Subject: [PATCH 062/152] Update the interface for plugins removing functions for creating creds --- builtin/logical/database/dbplugin/client.go | 37 ++------ .../database/dbplugin/databasemiddleware.go | 93 ++----------------- builtin/logical/database/dbplugin/plugin.go | 24 ++--- builtin/logical/database/dbplugin/server.go | 28 +----- builtin/logical/database/path_role_create.go | 19 +--- builtin/logical/database/secret_creds.go | 4 +- 6 files changed, 28 insertions(+), 177 deletions(-) diff --git a/builtin/logical/database/dbplugin/client.go b/builtin/logical/database/dbplugin/client.go index db6b3d1fdaa4..0dae61d27647 100644 --- a/builtin/logical/database/dbplugin/client.go +++ b/builtin/logical/database/dbplugin/client.go @@ -78,20 +78,20 @@ func (dr *databasePluginRPCClient) Type() string { return fmt.Sprintf("plugin-%s", dbType) } -func (dr *databasePluginRPCClient) CreateUser(statements Statements, username, password, expiration string) error { +func (dr *databasePluginRPCClient) CreateUser(statements Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) { req := CreateUserRequest{ - Statements: statements, - Username: username, - Password: password, - Expiration: expiration, + Statements: statements, + UsernamePrefix: usernamePrefix, + Expiration: expiration, } - err := dr.client.Call("Plugin.CreateUser", req, &struct{}{}) + var resp CreateUserResponse + err = dr.client.Call("Plugin.CreateUser", req, &resp) - return err + return resp.Username, resp.Password, err } -func (dr *databasePluginRPCClient) RenewUser(statements Statements, username, expiration string) error { +func (dr *databasePluginRPCClient) RenewUser(statements Statements, username string, expiration time.Time) error { req := RenewUserRequest{ Statements: statements, Username: username, @@ -125,24 +125,3 @@ func (dr *databasePluginRPCClient) Close() error { return err } - -func (dr *databasePluginRPCClient) GenerateUsername(displayName string) (string, error) { - resp := &GenerateUsernameResponse{} - err := dr.client.Call("Plugin.GenerateUsername", displayName, resp) - - return resp.Username, err -} - -func (dr *databasePluginRPCClient) GeneratePassword() (string, error) { - resp := &GeneratePasswordResponse{} - err := dr.client.Call("Plugin.GeneratePassword", struct{}{}, resp) - - return resp.Password, err -} - -func (dr *databasePluginRPCClient) GenerateExpiration(duration time.Duration) (string, error) { - resp := &GenerateExpirationResponse{} - err := dr.client.Call("Plugin.GenerateExpiration", duration, resp) - - return resp.Expiration, err -} diff --git a/builtin/logical/database/dbplugin/databasemiddleware.go b/builtin/logical/database/dbplugin/databasemiddleware.go index b4a9809508b0..2748f2f11f47 100644 --- a/builtin/logical/database/dbplugin/databasemiddleware.go +++ b/builtin/logical/database/dbplugin/databasemiddleware.go @@ -20,7 +20,7 @@ func (mw *databaseTracingMiddleware) Type() string { return mw.next.Type() } -func (mw *databaseTracingMiddleware) CreateUser(statements Statements, username, password, expiration string) (err error) { +func (mw *databaseTracingMiddleware) CreateUser(statements Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) { if mw.logger.IsTrace() { defer func(then time.Time) { mw.logger.Trace("database/CreateUser: finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) @@ -28,10 +28,10 @@ func (mw *databaseTracingMiddleware) CreateUser(statements Statements, username, mw.logger.Trace("database/CreateUser: starting", "type", mw.typeStr) } - return mw.next.CreateUser(statements, username, password, expiration) + return mw.next.CreateUser(statements, usernamePrefix, expiration) } -func (mw *databaseTracingMiddleware) RenewUser(statements Statements, username, expiration string) (err error) { +func (mw *databaseTracingMiddleware) RenewUser(statements Statements, username string, expiration time.Time) (err error) { if mw.logger.IsTrace() { defer func(then time.Time) { mw.logger.Trace("database/RenewUser: finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) @@ -75,39 +75,6 @@ func (mw *databaseTracingMiddleware) Close() (err error) { return mw.next.Close() } -func (mw *databaseTracingMiddleware) GenerateUsername(displayName string) (_ string, err error) { - if mw.logger.IsTrace() { - defer func(then time.Time) { - mw.logger.Trace("database/GenerateUsername: finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) - }(time.Now()) - - mw.logger.Trace("database/GenerateUsername: starting", "type", mw.typeStr) - } - return mw.next.GenerateUsername(displayName) -} - -func (mw *databaseTracingMiddleware) GeneratePassword() (_ string, err error) { - if mw.logger.IsTrace() { - defer func(then time.Time) { - mw.logger.Trace("database/GeneratePassword: finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) - }(time.Now()) - - mw.logger.Trace("database/GeneratePassword: starting", "type", mw.typeStr) - } - return mw.next.GeneratePassword() -} - -func (mw *databaseTracingMiddleware) GenerateExpiration(duration time.Duration) (_ string, err error) { - if mw.logger.IsTrace() { - defer func(then time.Time) { - mw.logger.Trace("database/GenerateExpiration: finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) - }(time.Now()) - - mw.logger.Trace("database/GenerateExpiration: starting", "type", mw.typeStr) - } - return mw.next.GenerateExpiration(duration) -} - // ---- Metrics Middleware Domain ---- type databaseMetricsMiddleware struct { @@ -120,7 +87,7 @@ func (mw *databaseMetricsMiddleware) Type() string { return mw.next.Type() } -func (mw *databaseMetricsMiddleware) CreateUser(statements Statements, username, password, expiration string) (err error) { +func (mw *databaseMetricsMiddleware) CreateUser(statements Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) { defer func(now time.Time) { metrics.MeasureSince([]string{"database", "CreateUser"}, now) metrics.MeasureSince([]string{"database", mw.typeStr, "CreateUser"}, now) @@ -133,10 +100,10 @@ func (mw *databaseMetricsMiddleware) CreateUser(statements Statements, username, metrics.IncrCounter([]string{"database", "CreateUser"}, 1) metrics.IncrCounter([]string{"database", mw.typeStr, "CreateUser"}, 1) - return mw.next.CreateUser(statements, username, password, expiration) + return mw.next.CreateUser(statements, usernamePrefix, expiration) } -func (mw *databaseMetricsMiddleware) RenewUser(statements Statements, username, expiration string) (err error) { +func (mw *databaseMetricsMiddleware) RenewUser(statements Statements, username string, expiration time.Time) (err error) { defer func(now time.Time) { metrics.MeasureSince([]string{"database", "RenewUser"}, now) metrics.MeasureSince([]string{"database", mw.typeStr, "RenewUser"}, now) @@ -199,51 +166,3 @@ func (mw *databaseMetricsMiddleware) Close() (err error) { metrics.IncrCounter([]string{"database", mw.typeStr, "Close"}, 1) return mw.next.Close() } - -func (mw *databaseMetricsMiddleware) GenerateUsername(displayName string) (_ string, err error) { - defer func(now time.Time) { - metrics.MeasureSince([]string{"database", "GenerateUsername"}, now) - metrics.MeasureSince([]string{"database", mw.typeStr, "GenerateUsername"}, now) - - if err != nil { - metrics.IncrCounter([]string{"database", "GenerateUsername", "error"}, 1) - metrics.IncrCounter([]string{"database", mw.typeStr, "GenerateUsername", "error"}, 1) - } - }(time.Now()) - - metrics.IncrCounter([]string{"database", "GenerateUsername"}, 1) - metrics.IncrCounter([]string{"database", mw.typeStr, "GenerateUsername"}, 1) - return mw.next.GenerateUsername(displayName) -} - -func (mw *databaseMetricsMiddleware) GeneratePassword() (_ string, err error) { - defer func(now time.Time) { - metrics.MeasureSince([]string{"database", "GeneratePassword"}, now) - metrics.MeasureSince([]string{"database", mw.typeStr, "GeneratePassword"}, now) - - if err != nil { - metrics.IncrCounter([]string{"database", "GeneratePassword", "error"}, 1) - metrics.IncrCounter([]string{"database", mw.typeStr, "GeneratePassword", "error"}, 1) - } - }(time.Now()) - - metrics.IncrCounter([]string{"database", "GeneratePassword"}, 1) - metrics.IncrCounter([]string{"database", mw.typeStr, "GeneratePassword"}, 1) - return mw.next.GeneratePassword() -} - -func (mw *databaseMetricsMiddleware) GenerateExpiration(duration time.Duration) (_ string, err error) { - defer func(now time.Time) { - metrics.MeasureSince([]string{"database", "GenerateExpiration"}, now) - metrics.MeasureSince([]string{"database", mw.typeStr, "GenerateExpiration"}, now) - - if err != nil { - metrics.IncrCounter([]string{"database", "GenerateExpiration", "error"}, 1) - metrics.IncrCounter([]string{"database", mw.typeStr, "GenerateExpiration", "error"}, 1) - } - }(time.Now()) - - metrics.IncrCounter([]string{"database", "GenerateExpiration"}, 1) - metrics.IncrCounter([]string{"database", mw.typeStr, "GenerateExpiration"}, 1) - return mw.next.GenerateExpiration(duration) -} diff --git a/builtin/logical/database/dbplugin/plugin.go b/builtin/logical/database/dbplugin/plugin.go index 994f3b0ce95c..5cd24e8790d3 100644 --- a/builtin/logical/database/dbplugin/plugin.go +++ b/builtin/logical/database/dbplugin/plugin.go @@ -17,16 +17,12 @@ var ( // DatabaseType is the interface that all database objects must implement. type DatabaseType interface { Type() string - CreateUser(statements Statements, username, password, expiration string) error - RenewUser(statements Statements, username, expiration string) error + CreateUser(statements Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) + RenewUser(statements Statements, username string, expiration time.Time) error RevokeUser(statements Statements, username string) error Initialize(map[string]interface{}) error Close() error - - GenerateUsername(displayName string) (string, error) - GeneratePassword() (string, error) - GenerateExpiration(ttl time.Duration) (string, error) } // Statements set in role creation and passed into the database type's functions. @@ -96,16 +92,15 @@ func (DatabasePlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, e // ---- RPC Request Args Domain ---- type CreateUserRequest struct { - Statements Statements - Username string - Password string - Expiration string + Statements Statements + UsernamePrefix string + Expiration time.Time } type RenewUserRequest struct { Statements Statements Username string - Expiration string + Expiration time.Time } type RevokeUserRequest struct { @@ -115,12 +110,7 @@ type RevokeUserRequest struct { // ---- RPC Response Args Domain ---- -type GenerateUsernameResponse struct { +type CreateUserResponse struct { Username string -} -type GenerateExpirationResponse struct { - Expiration string -} -type GeneratePasswordResponse struct { Password string } diff --git a/builtin/logical/database/dbplugin/server.go b/builtin/logical/database/dbplugin/server.go index 018d9b8db1f6..2dddbaffda84 100644 --- a/builtin/logical/database/dbplugin/server.go +++ b/builtin/logical/database/dbplugin/server.go @@ -1,8 +1,6 @@ package dbplugin import ( - "time" - "github.com/hashicorp/go-plugin" "github.com/hashicorp/vault/helper/pluginutil" ) @@ -39,8 +37,9 @@ func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error { return nil } -func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequest, _ *struct{}) error { - err := ds.impl.CreateUser(args.Statements, args.Username, args.Password, args.Expiration) +func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequest, resp *CreateUserResponse) error { + var err error + resp.Username, resp.Password, err = ds.impl.CreateUser(args.Statements, args.UsernamePrefix, args.Expiration) return err } @@ -67,24 +66,3 @@ func (ds *databasePluginRPCServer) Close(_ struct{}, _ *struct{}) error { ds.impl.Close() return nil } - -func (ds *databasePluginRPCServer) GenerateUsername(args string, resp *GenerateUsernameResponse) error { - var err error - resp.Username, err = ds.impl.GenerateUsername(args) - - return err -} - -func (ds *databasePluginRPCServer) GeneratePassword(_ struct{}, resp *GeneratePasswordResponse) error { - var err error - resp.Password, err = ds.impl.GeneratePassword() - - return err -} - -func (ds *databasePluginRPCServer) GenerateExpiration(args time.Duration, resp *GenerateExpirationResponse) error { - var err error - resp.Expiration, err = ds.impl.GenerateExpiration(args) - - return err -} diff --git a/builtin/logical/database/path_role_create.go b/builtin/logical/database/path_role_create.go index d379ef26739b..5a16c8926254 100644 --- a/builtin/logical/database/path_role_create.go +++ b/builtin/logical/database/path_role_create.go @@ -2,6 +2,7 @@ package database import ( "fmt" + "time" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" @@ -48,24 +49,10 @@ func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framewo return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) } - // Generate the username, password and expiration - username, err := db.GenerateUsername(req.DisplayName) - if err != nil { - return nil, err - } - - password, err := db.GeneratePassword() - if err != nil { - return nil, err - } - - expiration, err := db.GenerateExpiration(role.DefaultTTL) - if err != nil { - return nil, err - } + expiration := time.Now().Add(role.DefaultTTL) // Create the user - err = db.CreateUser(role.Statements, username, password, expiration) + username, password, err := db.CreateUser(role.Statements, req.DisplayName, expiration) if err != nil { return nil, err } diff --git a/builtin/logical/database/secret_creds.go b/builtin/logical/database/secret_creds.go index 353541c0cc35..5701e373a6be 100644 --- a/builtin/logical/database/secret_creds.go +++ b/builtin/logical/database/secret_creds.go @@ -58,9 +58,7 @@ func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.Fi // Make sure we increase the VALID UNTIL endpoint for this user. if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() { - expiration := expireTime.Format("2006-01-02 15:04:05-0700") - - err := db.RenewUser(role.Statements, username, expiration) + err := db.RenewUser(role.Statements, username, expireTime) if err != nil { return nil, err } From 64efc505c8f34fbc2608c163bab4c4af61f1ec91 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 10 Apr 2017 14:12:28 -0700 Subject: [PATCH 063/152] Update plugin test --- builtin/logical/database/backend_test.go | 2 +- .../logical/database/dbplugin/plugin_test.go | 189 ++++-------------- vault/testing.go | 13 +- 3 files changed, 48 insertions(+), 156 deletions(-) diff --git a/builtin/logical/database/backend_test.go b/builtin/logical/database/backend_test.go index fc41cf3cd83c..5b3a0db42abe 100644 --- a/builtin/logical/database/backend_test.go +++ b/builtin/logical/database/backend_test.go @@ -81,7 +81,7 @@ func getCore(t *testing.T) (*vault.Core, net.Listener, logical.SystemView, strin core, _, token, ln := vault.TestCoreUnsealedWithListener(t) http.TestServerWithListener(t, ln, "", core) sys := vault.TestDynamicSystemView(core) - vault.TestAddTestPlugin(t, core, "postgresql-database-plugin", fmt.Sprintf("%s -test.run=TestBackend_PluginMain", os.Args[0])) + vault.TestAddTestPlugin(t, core, "postgresql-database-plugin", "TestBackend_PluginMain") return core, ln, sys, token } diff --git a/builtin/logical/database/dbplugin/plugin_test.go b/builtin/logical/database/dbplugin/plugin_test.go index 849e1ebbf463..7909bbd4e53d 100644 --- a/builtin/logical/database/dbplugin/plugin_test.go +++ b/builtin/logical/database/dbplugin/plugin_test.go @@ -1,17 +1,13 @@ -package dbplugin +package dbplugin_test import ( - "crypto/sha256" - "encoding/hex" "errors" - "fmt" - "io" "net" "os" - "os/exec" "testing" "time" + "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/http" "github.com/hashicorp/vault/logical" @@ -21,27 +17,26 @@ import ( type mockPlugin struct { users map[string][]string - CredentialsProducer } func (m *mockPlugin) Type() string { return "mock" } -func (m *mockPlugin) CreateUser(statements Statements, username, password, expiration string) error { - err := errors.New("err") - if username == "" || password == "" || expiration == "" { - return err +func (m *mockPlugin) CreateUser(statements dbplugin.Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) { + err = errors.New("err") + if usernamePrefix == "" || expiration.IsZero() { + return "", "", err } - if _, ok := m.users[username]; ok { - return err + if _, ok := m.users[usernamePrefix]; ok { + return "", "", err } - m.users[username] = []string{password, expiration} + m.users[usernamePrefix] = []string{password} - return nil + return usernamePrefix, "test", nil } -func (m *mockPlugin) RenewUser(statements Statements, username, expiration string) error { +func (m *mockPlugin) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { err := errors.New("err") - if username == "" || expiration == "" { + if username == "" || expiration.IsZero() { return err } @@ -51,7 +46,7 @@ func (m *mockPlugin) RenewUser(statements Statements, username, expiration strin return nil } -func (m *mockPlugin) RevokeUser(statements Statements, username string) error { +func (m *mockPlugin) RevokeUser(statements dbplugin.Statements, username string) error { err := errors.New("err") if username == "" { return err @@ -77,40 +72,11 @@ func (m *mockPlugin) Close() error { return nil } -func getConf(t *testing.T) *DatabaseConfig { - command := fmt.Sprintf("%s -test.run=TestPlugin_Main", os.Args[0]) - cmd := exec.Command(os.Args[0]) - hash := sha256.New() - - file, err := os.Open(cmd.Path) - if err != nil { - t.Fatal(err) - } - defer file.Close() - - _, err = io.Copy(hash, file) - if err != nil { - t.Fatal(err) - } - - sum := hash.Sum(nil) - - conf := &DatabaseConfig{ - DatabaseType: pluginTypeName, - PluginCommand: command, - PluginChecksum: hex.EncodeToString(sum), - ConnectionDetails: map[string]interface{}{ - "test": true, - }, - } - - return conf -} - func getCore(t *testing.T) (*vault.Core, net.Listener, logical.SystemView) { core, _, _, ln := vault.TestCoreUnsealedWithListener(t) http.TestServerWithListener(t, ln, "", core) sys := vault.TestDynamicSystemView(core) + vault.TestAddTestPlugin(t, core, "test-plugin", "TestPlugin_Main") return core, ln, sys } @@ -123,24 +89,26 @@ func TestPlugin_Main(t *testing.T) { } plugin := &mockPlugin{ - users: make(map[string][]string), - CredentialsProducer: &sqlCredentialsProducer{5, 50}, + users: make(map[string][]string), } - NewPluginServer(plugin) + dbplugin.NewPluginServer(plugin) } func TestPlugin_Initialize(t *testing.T) { _, ln, sys := getCore(t) defer ln.Close() - conf := getConf(t) - dbRaw, err := PluginFactory(conf, sys, &log.NullLogger{}) + dbRaw, err := dbplugin.PluginFactory("test-plugin", sys, &log.NullLogger{}) if err != nil { t.Fatalf("err: %s", err) } - err = dbRaw.Initialize(conf.ConnectionDetails) + connectionDetails := map[string]interface{}{ + "test": 1, + } + + err = dbRaw.Initialize(connectionDetails) if err != nil { t.Fatalf("err: %s", err) } @@ -155,97 +123,61 @@ func TestPlugin_CreateUser(t *testing.T) { _, ln, sys := getCore(t) defer ln.Close() - conf := getConf(t) - db, err := PluginFactory(conf, sys, &log.NullLogger{}) + db, err := dbplugin.PluginFactory("test-plugin", sys, &log.NullLogger{}) if err != nil { t.Fatalf("err: %s", err) } defer db.Close() - err = db.Initialize(conf.ConnectionDetails) - if err != nil { - t.Fatalf("err: %s", err) + connectionDetails := map[string]interface{}{ + "test": 1, } - username, err := db.GenerateUsername("test") + err = db.Initialize(connectionDetails) if err != nil { t.Fatalf("err: %s", err) } - password, err := db.GeneratePassword() + us, pw, err := db.CreateUser(dbplugin.Statements{}, "test", time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } - - expiration, err := db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) + if us != "test" || pw != "test" { + t.Fatal("expected username and password to be 'test'") } - err = db.CreateUser(Statements{}, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } // try and save the same user again to verify it saved the first time, this // should return an error - err = db.CreateUser(Statements{}, username, password, expiration) + _, _, err = db.CreateUser(dbplugin.Statements{}, "test", time.Now().Add(time.Minute)) if err == nil { t.Fatal("expected an error, user wasn't created correctly") } - - // Create one more user - username, err = db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - err = db.CreateUser(Statements{}, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } } func TestPlugin_RenewUser(t *testing.T) { _, ln, sys := getCore(t) defer ln.Close() - conf := getConf(t) - db, err := PluginFactory(conf, sys, &log.NullLogger{}) + db, err := dbplugin.PluginFactory("test-plugin", sys, &log.NullLogger{}) if err != nil { t.Fatalf("err: %s", err) } defer db.Close() - err = db.Initialize(conf.ConnectionDetails) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err := db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err := db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) + connectionDetails := map[string]interface{}{ + "test": 1, } - - expiration, err := db.GenerateExpiration(time.Minute) + err = db.Initialize(connectionDetails) if err != nil { t.Fatalf("err: %s", err) } - err = db.CreateUser(Statements{}, username, password, expiration) + us, _, err := db.CreateUser(dbplugin.Statements{}, "test", time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } - expiration, err = db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - err = db.RenewUser(Statements{}, username, expiration) + err = db.RenewUser(dbplugin.Statements{}, us, time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } @@ -255,69 +187,34 @@ func TestPlugin_RevokeUser(t *testing.T) { _, ln, sys := getCore(t) defer ln.Close() - conf := getConf(t) - db, err := PluginFactory(conf, sys, &log.NullLogger{}) + db, err := dbplugin.PluginFactory("test-plugin", sys, &log.NullLogger{}) if err != nil { t.Fatalf("err: %s", err) } defer db.Close() - err = db.Initialize(conf.ConnectionDetails) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err := db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err := db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) + connectionDetails := map[string]interface{}{ + "test": 1, } - - expiration, err := db.GenerateExpiration(time.Minute) + err = db.Initialize(connectionDetails) if err != nil { t.Fatalf("err: %s", err) } - err = db.CreateUser(Statements{}, username, password, expiration) + us, _, err := db.CreateUser(dbplugin.Statements{}, "test", time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } // Test default revoke statememts - err = db.RevokeUser(Statements{}, username) + err = db.RevokeUser(dbplugin.Statements{}, us) if err != nil { t.Fatalf("err: %s", err) } // Try adding the same username back so we can verify it was removed - err = db.CreateUser(Statements{}, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err = db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err = db.GenerateExpiration(time.Minute) + _, _, err = db.CreateUser(dbplugin.Statements{}, "test", time.Now().Add(time.Minute)) if err != nil { t.Fatalf("err: %s", err) } - - // try once more - err = db.CreateUser(Statements{}, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } - - err = db.RevokeUser(Statements{}, username) - if err != nil { - t.Fatalf("err: %s", err) - } - } diff --git a/vault/testing.go b/vault/testing.go index fdf55b4e595a..b2fe36b332f6 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -14,7 +14,6 @@ import ( "os" "os/exec" "path/filepath" - "strings" "testing" "time" @@ -320,10 +319,8 @@ func TestDynamicSystemView(c *Core) *dynamicSystemView { return &dynamicSystemView{c, me} } -func TestAddTestPlugin(t testing.TB, c *Core, name, command string) { - parts := strings.Split(command, " ") - - file, err := os.Open(parts[0]) +func TestAddTestPlugin(t testing.TB, c *Core, name, testFunc string) { + file, err := os.Open(os.Args[0]) if err != nil { t.Fatal(err) } @@ -337,15 +334,13 @@ func TestAddTestPlugin(t testing.TB, c *Core, name, command string) { } sum := hash.Sum(nil) - c.pluginCatalog.directory, err = filepath.EvalSymlinks(parts[0]) + c.pluginCatalog.directory, err = filepath.EvalSymlinks(os.Args[0]) if err != nil { t.Fatal(err) } c.pluginCatalog.directory = filepath.Dir(c.pluginCatalog.directory) - parts[0] = filepath.Base(parts[0]) - command = strings.Join(parts, " ") - + command := fmt.Sprintf("%s --test.run=%s", filepath.Base(os.Args[0]), testFunc) err = c.pluginCatalog.Set(name, command, sum) if err != nil { t.Fatal(err) From f54c4de98addaf809d7ed8ee5d6f8a3fa99216a9 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 10 Apr 2017 15:36:59 -0700 Subject: [PATCH 064/152] Add a flag to tell plugins to verify the connection was successful --- builtin/logical/database/backend.go | 2 +- builtin/logical/database/dbplugin/client.go | 9 +++++++-- .../database/dbplugin/databasemiddleware.go | 10 +++++----- builtin/logical/database/dbplugin/plugin.go | 8 ++++++-- builtin/logical/database/dbplugin/server.go | 4 ++-- .../logical/database/path_config_connection.go | 15 ++++----------- 6 files changed, 25 insertions(+), 23 deletions(-) diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index baa05a0923f0..4cf542d959f4 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -106,7 +106,7 @@ func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (dbpl return nil, err } - err = db.Initialize(config.ConnectionDetails) + err = db.Initialize(config.ConnectionDetails, true) if err != nil { return nil, err } diff --git a/builtin/logical/database/dbplugin/client.go b/builtin/logical/database/dbplugin/client.go index 0dae61d27647..da39ed425bff 100644 --- a/builtin/logical/database/dbplugin/client.go +++ b/builtin/logical/database/dbplugin/client.go @@ -114,8 +114,13 @@ func (dr *databasePluginRPCClient) RevokeUser(statements Statements, username st return err } -func (dr *databasePluginRPCClient) Initialize(conf map[string]interface{}) error { - err := dr.client.Call("Plugin.Initialize", conf, &struct{}{}) +func (dr *databasePluginRPCClient) Initialize(conf map[string]interface{}, verifyConnection bool) error { + req := InitializeRequest{ + Config: conf, + VerifyConnection: verifyConnection, + } + + err := dr.client.Call("Plugin.Initialize", req, &struct{}{}) return err } diff --git a/builtin/logical/database/dbplugin/databasemiddleware.go b/builtin/logical/database/dbplugin/databasemiddleware.go index 2748f2f11f47..1df7be3bb5c0 100644 --- a/builtin/logical/database/dbplugin/databasemiddleware.go +++ b/builtin/logical/database/dbplugin/databasemiddleware.go @@ -53,15 +53,15 @@ func (mw *databaseTracingMiddleware) RevokeUser(statements Statements, username return mw.next.RevokeUser(statements, username) } -func (mw *databaseTracingMiddleware) Initialize(conf map[string]interface{}) (err error) { +func (mw *databaseTracingMiddleware) Initialize(conf map[string]interface{}, verifyConnection bool) (err error) { if mw.logger.IsTrace() { defer func(then time.Time) { - mw.logger.Trace("database/Initialize: finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) + mw.logger.Trace("database/Initialize: finished", "type", mw.typeStr, "verify", verifyConnection, "err", err, "took", time.Since(then)) }(time.Now()) mw.logger.Trace("database/Initialize: starting", "type", mw.typeStr) } - return mw.next.Initialize(conf) + return mw.next.Initialize(conf, verifyConnection) } func (mw *databaseTracingMiddleware) Close() (err error) { @@ -135,7 +135,7 @@ func (mw *databaseMetricsMiddleware) RevokeUser(statements Statements, username return mw.next.RevokeUser(statements, username) } -func (mw *databaseMetricsMiddleware) Initialize(conf map[string]interface{}) (err error) { +func (mw *databaseMetricsMiddleware) Initialize(conf map[string]interface{}, verifyConnection bool) (err error) { defer func(now time.Time) { metrics.MeasureSince([]string{"database", "Initialize"}, now) metrics.MeasureSince([]string{"database", mw.typeStr, "Initialize"}, now) @@ -148,7 +148,7 @@ func (mw *databaseMetricsMiddleware) Initialize(conf map[string]interface{}) (er metrics.IncrCounter([]string{"database", "Initialize"}, 1) metrics.IncrCounter([]string{"database", mw.typeStr, "Initialize"}, 1) - return mw.next.Initialize(conf) + return mw.next.Initialize(conf, verifyConnection) } func (mw *databaseMetricsMiddleware) Close() (err error) { diff --git a/builtin/logical/database/dbplugin/plugin.go b/builtin/logical/database/dbplugin/plugin.go index 5cd24e8790d3..39655bf4657a 100644 --- a/builtin/logical/database/dbplugin/plugin.go +++ b/builtin/logical/database/dbplugin/plugin.go @@ -21,12 +21,11 @@ type DatabaseType interface { RenewUser(statements Statements, username string, expiration time.Time) error RevokeUser(statements Statements, username string) error - Initialize(map[string]interface{}) error + Initialize(config map[string]interface{}, verifyConnection bool) error Close() error } // Statements set in role creation and passed into the database type's functions. -// TODO: Add a way of setting defaults here. type Statements struct { CreationStatements string `json:"creation_statments" mapstructure:"creation_statements" structs:"creation_statments"` RevocationStatements string `json:"revocation_statements" mapstructure:"revocation_statements" structs:"revocation_statements"` @@ -91,6 +90,11 @@ func (DatabasePlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, e // ---- RPC Request Args Domain ---- +type InitializeRequest struct { + Config map[string]interface{} + VerifyConnection bool +} + type CreateUserRequest struct { Statements Statements UsernamePrefix string diff --git a/builtin/logical/database/dbplugin/server.go b/builtin/logical/database/dbplugin/server.go index 2dddbaffda84..54b05338c983 100644 --- a/builtin/logical/database/dbplugin/server.go +++ b/builtin/logical/database/dbplugin/server.go @@ -56,8 +56,8 @@ func (ds *databasePluginRPCServer) RevokeUser(args *RevokeUserRequest, _ *struct return err } -func (ds *databasePluginRPCServer) Initialize(args map[string]interface{}, _ *struct{}) error { - err := ds.impl.Initialize(args) +func (ds *databasePluginRPCServer) Initialize(args *InitializeRequest, _ *struct{}) error { + err := ds.impl.Initialize(args.Config, args.VerifyConnection) return err } diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 7589669a4625..8e78aa425e15 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -2,7 +2,6 @@ package database import ( "fmt" - "strings" "github.com/fatih/structs" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" @@ -169,23 +168,17 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil } - err = db.Initialize(config.ConnectionDetails) + err = db.Initialize(config.ConnectionDetails, verifyConnection) if err != nil { - if !strings.Contains(err.Error(), "error initalizing connection") { - db.Close() - return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil - } - - if verifyConnection { - db.Close() - return logical.ErrorResponse("Could not verify connection"), nil - } + db.Close() + return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil } if _, ok := b.connections[name]; ok { // Close and remove the old connection err := b.connections[name].Close() if err != nil { + db.Close() return nil, err } From de36d61e5af5e072d5fb8cfb6ee8076cff617a49 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 10 Apr 2017 17:12:52 -0700 Subject: [PATCH 065/152] Mlock the plugin process --- builtin/logical/database/backend.go | 2 ++ builtin/logical/database/dbplugin/server.go | 8 ++++++ helper/pluginutil/runner.go | 32 ++++++++++++++++++++- helper/pluginutil/tls.go | 4 --- logical/system_view.go | 10 +++++++ vault/core.go | 5 +++- vault/dynamic_system_view.go | 7 +++++ vault/plugin_catalog.go | 2 ++ 8 files changed, 64 insertions(+), 6 deletions(-) diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index 4cf542d959f4..618ffac6f809 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -76,6 +76,8 @@ func (b *databaseBackend) closeAllDBs() { for _, db := range b.connections { db.Close() } + + b.connections = nil } // This function is used to retrieve a database object either from the cached diff --git a/builtin/logical/database/dbplugin/server.go b/builtin/logical/database/dbplugin/server.go index 54b05338c983..5c1b41a3d1d3 100644 --- a/builtin/logical/database/dbplugin/server.go +++ b/builtin/logical/database/dbplugin/server.go @@ -1,6 +1,8 @@ package dbplugin import ( + "fmt" + "github.com/hashicorp/go-plugin" "github.com/hashicorp/vault/helper/pluginutil" ) @@ -18,6 +20,12 @@ func NewPluginServer(db DatabaseType) { "database": dbPlugin, } + err := pluginutil.OptionallyEnableMlock() + if err != nil { + fmt.Println(err) + return + } + plugin.Serve(&plugin.ServeConfig{ HandshakeConfig: handshakeConfig, Plugins: pluginMap, diff --git a/helper/pluginutil/runner.go b/helper/pluginutil/runner.go index 90569dd9ad1f..4d66d8706bce 100644 --- a/helper/pluginutil/runner.go +++ b/helper/pluginutil/runner.go @@ -3,15 +3,29 @@ package pluginutil import ( "crypto/sha256" "fmt" + "os" "os/exec" + "time" plugin "github.com/hashicorp/go-plugin" + "github.com/hashicorp/vault/helper/mlock" +) + +var ( + // PluginUnwrapTokenEnv is the ENV name used to pass unwrap tokens to the + // plugin. + PluginMlockEnabled = "VAULT_PLUGIN_MLOCK_ENABLED" ) type Looker interface { LookupPlugin(string) (*PluginRunner, error) } +type Wrapper interface { + ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (string, error) + MlockDisabled() bool +} + type LookWrapper interface { Looker Wrapper @@ -22,6 +36,7 @@ type PluginRunner struct { Command string `json:"command"` Args []string `json:"args"` Sha256 []byte `json:"sha256"` + Builtin bool `json:"builtin"` } func (r *PluginRunner) Run(wrapper Wrapper, pluginMap map[string]plugin.Plugin, hs plugin.HandshakeConfig, env []string) (*plugin.Client, error) { @@ -44,10 +59,17 @@ func (r *PluginRunner) Run(wrapper Wrapper, pluginMap map[string]plugin.Plugin, return nil, err } - // Add the response wrap token to the ENV of the plugin + mlock := "true" + if wrapper.MlockDisabled() { + mlock = "false" + } + cmd := exec.Command(r.Command, r.Args...) cmd.Env = append(cmd.Env, env...) + // Add the response wrap token to the ENV of the plugin cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginUnwrapTokenEnv, wrapToken)) + // Add the mlock setting to the ENV of the plugin + cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginMlockEnabled, mlock)) secureConfig := &plugin.SecureConfig{ Checksum: r.Sha256, @@ -64,3 +86,11 @@ func (r *PluginRunner) Run(wrapper Wrapper, pluginMap map[string]plugin.Plugin, return client, nil } + +func OptionallyEnableMlock() error { + if os.Getenv(PluginMlockEnabled) == "true" { + return mlock.LockMemory() + } + + return nil +} diff --git a/helper/pluginutil/tls.go b/helper/pluginutil/tls.go index 63ae2932f172..c7aa42ee608e 100644 --- a/helper/pluginutil/tls.go +++ b/helper/pluginutil/tls.go @@ -29,10 +29,6 @@ var ( PluginUnwrapTokenEnv = "VAULT_UNWRAP_TOKEN" ) -type Wrapper interface { - ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (string, error) -} - // GenerateCACert returns a CA cert used to later sign the certificates for the // plugin client and server. func GenerateCACert() ([]byte, *x509.Certificate, *ecdsa.PrivateKey, error) { diff --git a/logical/system_view.go b/logical/system_view.go index a9626bc50ee6..b69f2709014e 100644 --- a/logical/system_view.go +++ b/logical/system_view.go @@ -44,7 +44,12 @@ type SystemView interface { // token used to unwrap. ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (string, error) + // LookupPlugin looks into the plugin catalog for a plugin with the given + // name. Returns a PluginRunner or an error if a plugin can not be found. LookupPlugin(string) (*pluginutil.PluginRunner, error) + + // MlockDisabled returns the configuration setting for DisableMlock. + MlockDisabled() bool } type StaticSystemView struct { @@ -54,6 +59,7 @@ type StaticSystemView struct { TaintedVal bool CachingDisabledVal bool Primary bool + DisableMlock bool ReplicationStateVal consts.ReplicationState } @@ -88,3 +94,7 @@ func (d StaticSystemView) ResponseWrapData(data map[string]interface{}, ttl time func (d StaticSystemView) LookupPlugin(name string) (*pluginutil.PluginRunner, error) { return nil, errors.New("LookupPlugin is not implimented in StaticSystemView") } + +func (d StaticSystemView) MlockDisabled() bool { + return d.DisableMlock +} diff --git a/vault/core.go b/vault/core.go index ffd36683be59..9a2f1900ef8d 100644 --- a/vault/core.go +++ b/vault/core.go @@ -332,7 +332,7 @@ type Core struct { // uiEnabled indicates whether Vault Web UI is enabled or not uiEnabled bool - // pluginDirectory is the location vault will look for plugins + // pluginDirectory is the location vault will look for plugin binaries pluginDirectory string // vaultBinaryLocation is used to run builtin plugins in secure mode @@ -343,6 +343,8 @@ type Core struct { // pluginCatalog is used to manage plugin configurations pluginCatalog *PluginCatalog + + disableMlock bool } // CoreConfig is used to parameterize a core @@ -449,6 +451,7 @@ func NewCore(conf *CoreConfig) (*Core, error) { clusterListenerShutdownSuccessCh: make(chan struct{}), vaultBinaryLocation: conf.VaultBinaryLocation, vaultBinarySHA256: conf.VaultBinarySHA256, + disableMlock: conf.DisableMlock, } // Wrap the physical backend in a cache layer if enabled and not already wrapped diff --git a/vault/dynamic_system_view.go b/vault/dynamic_system_view.go index f318f3ab13cd..ca2b89d6c881 100644 --- a/vault/dynamic_system_view.go +++ b/vault/dynamic_system_view.go @@ -116,6 +116,13 @@ func (d dynamicSystemView) ResponseWrapData(data map[string]interface{}, ttl tim return resp.WrapInfo.Token, nil } +// LookupPlugin looks for a plugin with the given name in the plugin catalog. It +// returns a PluginRunner or an error if no plugin was found. func (d dynamicSystemView) LookupPlugin(name string) (*pluginutil.PluginRunner, error) { return d.core.pluginCatalog.Get(name) } + +// MlockDisabled returns the configuration setting "DisableMlock". +func (d dynamicSystemView) MlockDisabled() bool { + return d.core.disableMlock +} diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go index a42f85ec115a..737f0c26b91d 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugin_catalog.go @@ -63,6 +63,7 @@ func (c *PluginCatalog) Get(name string) (*pluginutil.PluginRunner, error) { Command: c.vaultCommand, Args: []string{"plugin-exec", name}, Sha256: c.vaultSHA256, + Builtin: true, }, nil } @@ -93,6 +94,7 @@ func (c *PluginCatalog) Set(name, command string, sha256 []byte) error { Command: command, Args: args, Sha256: sha256, + Builtin: false, } buf, err := json.Marshal(entry) From da4d9a8b4fe9a8eda3492b561b12293758936af5 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 10 Apr 2017 18:38:34 -0700 Subject: [PATCH 066/152] Remove unnecessary abstraction --- .../logical/database/path_config_connection.go | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 8e78aa425e15..c242aa33982a 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -51,15 +51,8 @@ func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framew // pathConfigurePluginConnection returns a configured framework.Path setup to // operate on plugins. func pathConfigurePluginConnection(b *databaseBackend) *framework.Path { - return buildConfigConnectionPath("config/%s", b.connectionWriteHandler(), b.connectionReadHandler(), b.connectionDeleteHandler()) -} - -// buildConfigConnectionPath reutns a configured framework.Path using the passed -// in operation functions to complete the request. Used to distinguish calls -// between builtin and plugin databases. -func buildConfigConnectionPath(path string, updateOp, readOp, deleteOp framework.OperationFunc) *framework.Path { return &framework.Path{ - Pattern: fmt.Sprintf(path, framework.GenericNameRegex("name")), + Pattern: fmt.Sprintf("config/%s", framework.GenericNameRegex("name")), Fields: map[string]*framework.FieldSchema{ "name": &framework.FieldSchema{ Type: framework.TypeString, @@ -80,9 +73,9 @@ func buildConfigConnectionPath(path string, updateOp, readOp, deleteOp framework }, Callbacks: map[logical.Operation]framework.OperationFunc{ - logical.UpdateOperation: updateOp, - logical.ReadOperation: readOp, - logical.DeleteOperation: deleteOp, + logical.UpdateOperation: b.connectionWriteHandler(), + logical.ReadOperation: b.connectionReadHandler(), + logical.DeleteOperation: b.connectionDeleteHandler(), }, HelpSynopsis: pathConfigConnectionHelpSyn, From 8f75c3031170879b78694e1dcc8c7b0b90e31b04 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 11 Apr 2017 11:50:34 -0700 Subject: [PATCH 067/152] Update help text and comments --- builtin/logical/database/dbplugin/client.go | 4 +- .../database/dbplugin/databasemiddleware.go | 4 + builtin/logical/database/dbplugin/plugin.go | 4 + builtin/logical/database/dbplugin/server.go | 5 +- .../database/path_config_connection.go | 81 ++++--- builtin/logical/database/path_role_create.go | 70 +++--- builtin/logical/database/path_roles.go | 214 ++++++++++-------- builtin/logical/database/secret_creds.go | 166 +++++++------- logical/system_view.go | 4 +- 9 files changed, 302 insertions(+), 250 deletions(-) diff --git a/builtin/logical/database/dbplugin/client.go b/builtin/logical/database/dbplugin/client.go index da39ed425bff..5bdc3a01a0db 100644 --- a/builtin/logical/database/dbplugin/client.go +++ b/builtin/logical/database/dbplugin/client.go @@ -10,7 +10,7 @@ import ( "github.com/hashicorp/vault/helper/pluginutil" ) -// DatabasePluginClient embeds a databasePluginRPCClient and wraps it's close +// DatabasePluginClient embeds a databasePluginRPCClient and wraps it's Close // method to also call Kill() on the plugin.Client. type DatabasePluginClient struct { client *plugin.Client @@ -64,7 +64,7 @@ func newPluginClient(sys pluginutil.Wrapper, pluginRunner *pluginutil.PluginRunn // ---- RPC client domain ---- -// databasePluginRPCClient impliments DatabaseType and is used on the client to +// databasePluginRPCClient implements DatabaseType and is used on the client to // make RPC calls to a plugin. type databasePluginRPCClient struct { client *rpc.Client diff --git a/builtin/logical/database/dbplugin/databasemiddleware.go b/builtin/logical/database/dbplugin/databasemiddleware.go index 1df7be3bb5c0..2137cd9c388b 100644 --- a/builtin/logical/database/dbplugin/databasemiddleware.go +++ b/builtin/logical/database/dbplugin/databasemiddleware.go @@ -9,6 +9,8 @@ import ( // ---- Tracing Middleware Domain ---- +// databaseTracingMiddleware wraps a implementation of DatabaseType and executes +// trace logging on function call. type databaseTracingMiddleware struct { next DatabaseType logger log.Logger @@ -77,6 +79,8 @@ func (mw *databaseTracingMiddleware) Close() (err error) { // ---- Metrics Middleware Domain ---- +// databaseMetricsMiddleware wraps an implementation of DatabaseTypes and on +// function call logs metrics about this instance. type databaseMetricsMiddleware struct { next DatabaseType diff --git a/builtin/logical/database/dbplugin/plugin.go b/builtin/logical/database/dbplugin/plugin.go index 39655bf4657a..dadb6639eea4 100644 --- a/builtin/logical/database/dbplugin/plugin.go +++ b/builtin/logical/database/dbplugin/plugin.go @@ -40,11 +40,13 @@ func PluginFactory(pluginName string, sys pluginutil.LookWrapper, logger log.Log return nil, ErrEmptyPluginName } + // Look for plugin in the plugin catalog pluginMeta, err := sys.LookupPlugin(pluginName) if err != nil { return nil, err } + // create a DatabasePluginClient instance db, err := newPluginClient(sys, pluginMeta) if err != nil { return nil, err @@ -76,6 +78,8 @@ var handshakeConfig = plugin.HandshakeConfig{ MagicCookieValue: "926a0820-aea2-be28-51d6-83cdf00e8edb", } +// DatabasePlugin implements go-plugin's Plugin interface. It has methods for +// retrieving a server and a client instance of the plugin. type DatabasePlugin struct { impl DatabaseType } diff --git a/builtin/logical/database/dbplugin/server.go b/builtin/logical/database/dbplugin/server.go index 5c1b41a3d1d3..326e25103cff 100644 --- a/builtin/logical/database/dbplugin/server.go +++ b/builtin/logical/database/dbplugin/server.go @@ -8,7 +8,7 @@ import ( ) // NewPluginServer is called from within a plugin and wraps the provided -// DatabaseType implimentation in a databasePluginRPCServer object and starts a +// DatabaseType implementation in a databasePluginRPCServer object and starts a // RPC server. func NewPluginServer(db DatabaseType) { dbPlugin := &DatabasePlugin{ @@ -35,7 +35,8 @@ func NewPluginServer(db DatabaseType) { // ---- RPC server domain ---- -// databasePluginRPCServer impliments DatabaseType and is run inside a plugin +// databasePluginRPCServer implements an RPC version of DatabaseType and is run +// inside a plugin. It wraps an underlying implementation of DatabaseType. type databasePluginRPCServer struct { impl DatabaseType } diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index c242aa33982a..5817f53c2d8c 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -9,6 +9,7 @@ import ( "github.com/hashicorp/vault/logical/framework" ) +// pathResetConnection configures a path to reset a plugin. func pathResetConnection(b *databaseBackend) *framework.Path { return &framework.Path{ Pattern: fmt.Sprintf("reset/%s", framework.GenericNameRegex("name")), @@ -20,32 +21,36 @@ func pathResetConnection(b *databaseBackend) *framework.Path { }, Callbacks: map[logical.Operation]framework.OperationFunc{ - logical.UpdateOperation: b.pathConnectionReset, + logical.UpdateOperation: b.pathConnectionReset(), }, - HelpSynopsis: pathConfigConnectionHelpSyn, - HelpDescription: pathConfigConnectionHelpDesc, + HelpSynopsis: pathResetConnectionHelpSyn, + HelpDescription: pathResetConnectionHelpDesc, } } -func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - name := data.Get("name").(string) - if name == "" { - return logical.ErrorResponse("Empty name attribute given"), nil - } +// pathConnectionReset resets a plugin by closing the existing instance and +// creating a new one. +func (b *databaseBackend) pathConnectionReset() framework.OperationFunc { + return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + name := data.Get("name").(string) + if name == "" { + return logical.ErrorResponse("Empty name attribute given"), nil + } + + // Grab the mutex lock + b.Lock() + defer b.Unlock() - // Grab the mutex lock - b.Lock() - defer b.Unlock() + b.clearConnection(name) - b.clearConnection(name) + _, err := b.getOrCreateDBObj(req.Storage, name) + if err != nil { + return nil, err + } - _, err := b.getOrCreateDBObj(req.Storage, name) - if err != nil { - return nil, err + return nil, nil } - - return nil, nil } // pathConfigurePluginConnection returns a configured framework.Path setup to @@ -60,15 +65,17 @@ func pathConfigurePluginConnection(b *databaseBackend) *framework.Path { }, "verify_connection": &framework.FieldSchema{ - Type: framework.TypeBool, - Default: true, - Description: `If set, connection_url is verified by actually connecting to the database`, + Type: framework.TypeBool, + Default: true, + Description: `If set, the connection details are verified by + actually connecting to the database`, }, "plugin_name": &framework.FieldSchema{ Type: framework.TypeString, - Description: `Maximum amount of time a connection may be reused; - a zero or negative value reuses connections forever.`, + Description: `The name of a builtin or previously registered + plugin known to vault. This endpoint will create an instance of + that plugin type.`, }, }, @@ -198,16 +205,32 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { } const pathConfigConnectionHelpSyn = ` -Configure the connection string to talk to PostgreSQL. +Configure connection details to a database plugin. ` const pathConfigConnectionHelpDesc = ` -This path configures the connection string used to connect to PostgreSQL. -The value of the string can be a URL, or a PG style string in the -format of "user=foo host=bar" etc. +This path configures the connection details used to connect to a particular +database. This path runs the provided plugin name and passes the configured +connection details to the plugin. See the documentation for the plugin specified +for a full list of accepted connection details. -The URL looks like: -"postgresql://user:pass@host:port/dbname" +In addition to the database specific connection details, this endpoing also +accepts: + + * "plugin_name" (required) - The name of a builtin or previously registered + plugin known to vault. This endpoint will create an instance of that + plugin type. + + * "verify_connection" - A boolean value denoting if the plugin should verify + it is able to connect to the database using the provided connection + details. +` + +const pathResetConnectionHelpSyn = ` +Resets a database plugin. +` -When configuring the connection string, the backend will verify its validity. +const pathResetConnectionHelpDesc = ` +This path resets the database connection by closing the existing database plugin +instance and running a new one. ` diff --git a/builtin/logical/database/path_role_create.go b/builtin/logical/database/path_role_create.go index 5a16c8926254..59584e9437cf 100644 --- a/builtin/logical/database/path_role_create.go +++ b/builtin/logical/database/path_role_create.go @@ -19,7 +19,7 @@ func pathRoleCreate(b *databaseBackend) *framework.Path { }, Callbacks: map[logical.Operation]framework.OperationFunc{ - logical.ReadOperation: b.pathRoleCreateRead, + logical.ReadOperation: b.pathRoleCreateRead(), }, HelpSynopsis: pathRoleCreateReadHelpSyn, @@ -27,45 +27,47 @@ func pathRoleCreate(b *databaseBackend) *framework.Path { } } -func (b *databaseBackend) pathRoleCreateRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - name := data.Get("name").(string) +func (b *databaseBackend) pathRoleCreateRead() framework.OperationFunc { + return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + name := data.Get("name").(string) - // Get the role - role, err := b.Role(req.Storage, name) - if err != nil { - return nil, err - } - if role == nil { - return logical.ErrorResponse(fmt.Sprintf("Unknown role: %s", name)), nil - } + // Get the role + role, err := b.Role(req.Storage, name) + if err != nil { + return nil, err + } + if role == nil { + return logical.ErrorResponse(fmt.Sprintf("Unknown role: %s", name)), nil + } - b.Lock() - defer b.Unlock() + b.Lock() + defer b.Unlock() - // Get the Database object - db, err := b.getOrCreateDBObj(req.Storage, role.DBName) - if err != nil { - // TODO: return a resp error instead? - return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) - } + // Get the Database object + db, err := b.getOrCreateDBObj(req.Storage, role.DBName) + if err != nil { + // TODO: return a resp error instead? + return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) + } - expiration := time.Now().Add(role.DefaultTTL) + expiration := time.Now().Add(role.DefaultTTL) - // Create the user - username, password, err := db.CreateUser(role.Statements, req.DisplayName, expiration) - if err != nil { - return nil, err - } + // Create the user + username, password, err := db.CreateUser(role.Statements, req.DisplayName, expiration) + if err != nil { + return nil, err + } - resp := b.Secret(SecretCredsType).Response(map[string]interface{}{ - "username": username, - "password": password, - }, map[string]interface{}{ - "username": username, - "role": name, - }) - resp.Secret.TTL = role.DefaultTTL - return resp, nil + resp := b.Secret(SecretCredsType).Response(map[string]interface{}{ + "username": username, + "password": password, + }, map[string]interface{}{ + "username": username, + "role": name, + }) + resp.Secret.TTL = role.DefaultTTL + return resp, nil + } } const pathRoleCreateReadHelpSyn = ` diff --git a/builtin/logical/database/path_roles.go b/builtin/logical/database/path_roles.go index a6989df2481c..263a555e6b3b 100644 --- a/builtin/logical/database/path_roles.go +++ b/builtin/logical/database/path_roles.go @@ -14,7 +14,7 @@ func pathListRoles(b *databaseBackend) *framework.Path { Pattern: "roles/?$", Callbacks: map[logical.Operation]framework.OperationFunc{ - logical.ListOperation: b.pathRoleList, + logical.ListOperation: b.pathRoleList(), }, HelpSynopsis: pathRoleHelpSyn, @@ -35,12 +35,13 @@ func pathRoles(b *databaseBackend) *framework.Path { Type: framework.TypeString, Description: "Name of the database this role acts on.", }, - "creation_statements": { - Type: framework.TypeString, - Description: "SQL string to create a user. See help for more info.", + Type: framework.TypeString, + Description: `Statements to be executed to create a user. Must be a semicolon-separated + string, a base64-encoded semicolon-separated string, a serialized JSON string + array, or a base64-encoded serialized JSON string array. The '{{name}}', + '{{password}}', and '{{expiration}}' values will be substituted.`, }, - "revocation_statements": { Type: framework.TypeString, Description: `Statements to be executed to revoke a user. Must be a semicolon-separated @@ -75,9 +76,9 @@ func pathRoles(b *databaseBackend) *framework.Path { }, Callbacks: map[logical.Operation]framework.OperationFunc{ - logical.ReadOperation: b.pathRoleRead, - logical.UpdateOperation: b.pathRoleCreate, - logical.DeleteOperation: b.pathRoleDelete, + logical.ReadOperation: b.pathRoleRead(), + logical.UpdateOperation: b.pathRoleCreate(), + logical.DeleteOperation: b.pathRoleDelete(), }, HelpSynopsis: pathRoleHelpSyn, @@ -85,101 +86,107 @@ func pathRoles(b *databaseBackend) *framework.Path { } } -func (b *databaseBackend) pathRoleDelete(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - err := req.Storage.Delete("role/" + data.Get("name").(string)) - if err != nil { - return nil, err - } - - return nil, nil -} +func (b *databaseBackend) pathRoleDelete() framework.OperationFunc { + return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + err := req.Storage.Delete("role/" + data.Get("name").(string)) + if err != nil { + return nil, err + } -func (b *databaseBackend) pathRoleRead(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - role, err := b.Role(req.Storage, data.Get("name").(string)) - if err != nil { - return nil, err - } - if role == nil { return nil, nil } - - return &logical.Response{ - Data: map[string]interface{}{ - "creation_statements": role.Statements.CreationStatements, - "revocation_statements": role.Statements.RevocationStatements, - "rollback_statements": role.Statements.RollbackStatements, - "renew_statements": role.Statements.RenewStatements, - "default_ttl": role.DefaultTTL.String(), - "max_ttl": role.MaxTTL.String(), - }, - }, nil } -func (b *databaseBackend) pathRoleList(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { - entries, err := req.Storage.List("role/") - if err != nil { - return nil, err +func (b *databaseBackend) pathRoleRead() framework.OperationFunc { + return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + role, err := b.Role(req.Storage, data.Get("name").(string)) + if err != nil { + return nil, err + } + if role == nil { + return nil, nil + } + + return &logical.Response{ + Data: map[string]interface{}{ + "creation_statements": role.Statements.CreationStatements, + "revocation_statements": role.Statements.RevocationStatements, + "rollback_statements": role.Statements.RollbackStatements, + "renew_statements": role.Statements.RenewStatements, + "default_ttl": role.DefaultTTL.String(), + "max_ttl": role.MaxTTL.String(), + }, + }, nil } - - return logical.ListResponse(entries), nil } -func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - name := data.Get("name").(string) - if name == "" { - return logical.ErrorResponse("Empty role name attribute given"), nil - } +func (b *databaseBackend) pathRoleList() framework.OperationFunc { + return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + entries, err := req.Storage.List("role/") + if err != nil { + return nil, err + } - dbName := data.Get("db_name").(string) - if dbName == "" { - return logical.ErrorResponse("Empty database name attribute given"), nil - } - - // Get statements - creationStmts := data.Get("creation_statements").(string) - revocationStmts := data.Get("revocation_statements").(string) - rollbackStmts := data.Get("rollback_statements").(string) - renewStmts := data.Get("renew_statements").(string) - - // Get TTLs - defaultTTLRaw := data.Get("default_ttl").(string) - maxTTLRaw := data.Get("max_ttl").(string) - - defaultTTL, err := time.ParseDuration(defaultTTLRaw) - if err != nil { - return logical.ErrorResponse(fmt.Sprintf( - "Invalid default_ttl: %s", err)), nil - } - maxTTL, err := time.ParseDuration(maxTTLRaw) - if err != nil { - return logical.ErrorResponse(fmt.Sprintf( - "Invalid max_ttl: %s", err)), nil + return logical.ListResponse(entries), nil } +} - statements := dbplugin.Statements{ - CreationStatements: creationStmts, - RevocationStatements: revocationStmts, - RollbackStatements: rollbackStmts, - RenewStatements: renewStmts, - } +func (b *databaseBackend) pathRoleCreate() framework.OperationFunc { + return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + name := data.Get("name").(string) + if name == "" { + return logical.ErrorResponse("Empty role name attribute given"), nil + } + + dbName := data.Get("db_name").(string) + if dbName == "" { + return logical.ErrorResponse("Empty database name attribute given"), nil + } + + // Get statements + creationStmts := data.Get("creation_statements").(string) + revocationStmts := data.Get("revocation_statements").(string) + rollbackStmts := data.Get("rollback_statements").(string) + renewStmts := data.Get("renew_statements").(string) + + // Get TTLs + defaultTTLRaw := data.Get("default_ttl").(string) + maxTTLRaw := data.Get("max_ttl").(string) + + defaultTTL, err := time.ParseDuration(defaultTTLRaw) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf( + "Invalid default_ttl: %s", err)), nil + } + maxTTL, err := time.ParseDuration(maxTTLRaw) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf( + "Invalid max_ttl: %s", err)), nil + } + + statements := dbplugin.Statements{ + CreationStatements: creationStmts, + RevocationStatements: revocationStmts, + RollbackStatements: rollbackStmts, + RenewStatements: renewStmts, + } + + // Store it + entry, err := logical.StorageEntryJSON("role/"+name, &roleEntry{ + DBName: dbName, + Statements: statements, + DefaultTTL: defaultTTL, + MaxTTL: maxTTL, + }) + if err != nil { + return nil, err + } + if err := req.Storage.Put(entry); err != nil { + return nil, err + } - // TODO: Think about preparing the statments to test. - - // Store it - entry, err := logical.StorageEntryJSON("role/"+name, &roleEntry{ - DBName: dbName, - Statements: statements, - DefaultTTL: defaultTTL, - MaxTTL: maxTTL, - }) - if err != nil { - return nil, err - } - if err := req.Storage.Put(entry); err != nil { - return nil, err + return nil, nil } - - return nil, nil } type roleEntry struct { @@ -196,10 +203,14 @@ Manage the roles that can be created with this backend. const pathRoleHelpDesc = ` This path lets you manage the roles that can be created with this backend. -The "sql" parameter customizes the SQL string used to create the role. -This can be a sequence of SQL queries. Some substitution will be done to the -SQL string for certain keys. The names of the variables must be surrounded -by "{{" and "}}" to be replaced. +The "db_name" parameter is required and configures the name of the database +connection to use. + +The "creation_statements" parameter customizes the string used to create the +credentials. This can be a sequence of SQL queries, or other statement formats +for a particular database type. Some substitution will be done to the statement +strings for certain keys. The names of the variables must be surrounded by "{{" +and "}}" to be replaced. * "name" - The random username generated for the DB user. @@ -207,7 +218,7 @@ by "{{" and "}}" to be replaced. * "expiration" - The timestamp when this user will expire. -Example of a decent SQL query to use: +Example of a decent creation_statements for a postgresql database plugin: CREATE ROLE "{{name}}" WITH LOGIN @@ -215,14 +226,17 @@ Example of a decent SQL query to use: VALID UNTIL '{{expiration}}'; GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}"; -Note the above user would be able to access everything in schema public. -For more complex GRANT clauses, see the PostgreSQL manual. - -The "revocation_sql" parameter customizes the SQL string used to revoke a user. -Example of a decent revocation SQL query to use: +The "revocation_statements" parameter customizes the statement string used to +revoke a user. Example of a decent revocation_statements for a postgresql +database plugin: REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM {{name}}; REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM {{name}}; REVOKE USAGE ON SCHEMA public FROM {{name}}; DROP ROLE IF EXISTS {{name}}; + +The "renew_statements" parameter customizes the statement string used to renew a +user. +The "rollback_statements' parameter customizes the statement string used to +rollback a change if needed. ` diff --git a/builtin/logical/database/secret_creds.go b/builtin/logical/database/secret_creds.go index 5701e373a6be..ffc59cf3fec3 100644 --- a/builtin/logical/database/secret_creds.go +++ b/builtin/logical/database/secret_creds.go @@ -14,112 +14,116 @@ func secretCreds(b *databaseBackend) *framework.Secret { Type: SecretCredsType, Fields: map[string]*framework.FieldSchema{}, - Renew: b.secretCredsRenew, - Revoke: b.secretCredsRevoke, + Renew: b.secretCredsRenew(), + Revoke: b.secretCredsRevoke(), } } -func (b *databaseBackend) secretCredsRenew(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { - // Get the username from the internal data - usernameRaw, ok := req.Secret.InternalData["username"] - if !ok { - return nil, fmt.Errorf("secret is missing username internal data") - } - username, ok := usernameRaw.(string) - - roleNameRaw, ok := req.Secret.InternalData["role"] - if !ok { - return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"]) - } +func (b *databaseBackend) secretCredsRenew() framework.OperationFunc { + return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + // Get the username from the internal data + usernameRaw, ok := req.Secret.InternalData["username"] + if !ok { + return nil, fmt.Errorf("secret is missing username internal data") + } + username, ok := usernameRaw.(string) - role, err := b.Role(req.Storage, roleNameRaw.(string)) - if err != nil { - return nil, err - } - if role == nil { - return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"]) - } + roleNameRaw, ok := req.Secret.InternalData["role"] + if !ok { + return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"]) + } - f := framework.LeaseExtend(role.DefaultTTL, role.MaxTTL, b.System()) - resp, err := f(req, d) - if err != nil { - return nil, err - } + role, err := b.Role(req.Storage, roleNameRaw.(string)) + if err != nil { + return nil, err + } + if role == nil { + return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"]) + } - // Grab the read lock - b.Lock() - defer b.Unlock() + f := framework.LeaseExtend(role.DefaultTTL, role.MaxTTL, b.System()) + resp, err := f(req, data) + if err != nil { + return nil, err + } - // Get our connection - db, err := b.getOrCreateDBObj(req.Storage, role.DBName) - if err != nil { - return nil, fmt.Errorf("could not find connection with name %s, got err: %s", role.DBName, err) - } + // Grab the read lock + b.Lock() + defer b.Unlock() - // Make sure we increase the VALID UNTIL endpoint for this user. - if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() { - err := db.RenewUser(role.Statements, username, expireTime) + // Get our connection + db, err := b.getOrCreateDBObj(req.Storage, role.DBName) if err != nil { - return nil, err + return nil, fmt.Errorf("could not find connection with name %s, got err: %s", role.DBName, err) } - } - return resp, nil -} + // Make sure we increase the VALID UNTIL endpoint for this user. + if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() { + err := db.RenewUser(role.Statements, username, expireTime) + if err != nil { + return nil, err + } + } -func (b *databaseBackend) secretCredsRevoke(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { - // Get the username from the internal data - usernameRaw, ok := req.Secret.InternalData["username"] - if !ok { - return nil, fmt.Errorf("secret is missing username internal data") + return resp, nil } - username, ok := usernameRaw.(string) +} - var resp *logical.Response +func (b *databaseBackend) secretCredsRevoke() framework.OperationFunc { + return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + // Get the username from the internal data + usernameRaw, ok := req.Secret.InternalData["username"] + if !ok { + return nil, fmt.Errorf("secret is missing username internal data") + } + username, ok := usernameRaw.(string) - roleNameRaw, ok := req.Secret.InternalData["role"] - if !ok { - return nil, fmt.Errorf("no role name was provided") - } + var resp *logical.Response - role, err := b.Role(req.Storage, roleNameRaw.(string)) - if err != nil { - return nil, err - } - if role == nil { - return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"]) - } + roleNameRaw, ok := req.Secret.InternalData["role"] + if !ok { + return nil, fmt.Errorf("no role name was provided") + } - /* TODO: think about how to handle this case. - if !ok { role, err := b.Role(req.Storage, roleNameRaw.(string)) if err != nil { return nil, err } if role == nil { - if resp == nil { - resp = &logical.Response{} - } - resp.AddWarning(fmt.Sprintf("Role %q cannot be found. Using default revocation SQL.", roleNameRaw.(string))) - } else { - revocationSQL = role.RevocationStatement + return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"]) } - }*/ - // Grab the read lock - b.Lock() - defer b.Unlock() + /* TODO: think about how to handle this case. + if !ok { + role, err := b.Role(req.Storage, roleNameRaw.(string)) + if err != nil { + return nil, err + } + if role == nil { + if resp == nil { + resp = &logical.Response{} + } + resp.AddWarning(fmt.Sprintf("Role %q cannot be found. Using default revocation SQL.", roleNameRaw.(string))) + } else { + revocationSQL = role.RevocationStatement + } + }*/ + + // Grab the read lock + b.Lock() + defer b.Unlock() - // Get our connection - db, err := b.getOrCreateDBObj(req.Storage, role.DBName) - if err != nil { - return nil, fmt.Errorf("could not find database with name: %s, got error: %s", role.DBName, err) - } + // Get our connection + db, err := b.getOrCreateDBObj(req.Storage, role.DBName) + if err != nil { + return nil, fmt.Errorf("could not find database with name: %s, got error: %s", role.DBName, err) + } - err = db.RevokeUser(role.Statements, username) - if err != nil { - return nil, err - } + err = db.RevokeUser(role.Statements, username) + if err != nil { + return nil, err + } - return resp, nil + return resp, nil + } } diff --git a/logical/system_view.go b/logical/system_view.go index b69f2709014e..b6ab14b1fc01 100644 --- a/logical/system_view.go +++ b/logical/system_view.go @@ -88,11 +88,11 @@ func (d StaticSystemView) ReplicationState() consts.ReplicationState { } func (d StaticSystemView) ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (string, error) { - return "", errors.New("ResponseWrapData is not implimented in StaticSystemView") + return "", errors.New("ResponseWrapData is not implemented in StaticSystemView") } func (d StaticSystemView) LookupPlugin(name string) (*pluginutil.PluginRunner, error) { - return nil, errors.New("LookupPlugin is not implimented in StaticSystemView") + return nil, errors.New("LookupPlugin is not implemented in StaticSystemView") } func (d StaticSystemView) MlockDisabled() bool { From 8c264c6070b8a763c6972aeb1a0d2c77d55e4550 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 12 Apr 2017 09:40:54 -0700 Subject: [PATCH 068/152] Add remaining crud functions to plugin catalog and tests --- helper/builtinplugins/builtin.go | 12 +++ vault/logical_system.go | 31 ++++-- vault/plugin_catalog.go | 47 +++++++++ vault/plugin_catalog_test.go | 166 +++++++++++++++++++++++++++++++ 4 files changed, 249 insertions(+), 7 deletions(-) create mode 100644 vault/plugin_catalog_test.go diff --git a/helper/builtinplugins/builtin.go b/helper/builtinplugins/builtin.go index ba3769c900a1..55da9a97f310 100644 --- a/helper/builtinplugins/builtin.go +++ b/helper/builtinplugins/builtin.go @@ -22,3 +22,15 @@ func (b *builtinPlugins) Get(name string) (func() error, bool) { f, ok := b.plugins[name] return f, ok } + +func (b *builtinPlugins) Keys() []string { + keys := make([]string, len(b.plugins)) + + i := 0 + for k := range b.plugins { + keys[i] = k + i++ + } + + return keys +} diff --git a/vault/logical_system.go b/vault/logical_system.go index f5dbe2affa69..fadae02bf616 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -63,7 +63,6 @@ func NewSystemBackend(core *Core, config *logical.BackendConfig) (logical.Backen "replication/reindex", "rotate", "config/auditing/*", - "plugin-catalog", "plugin-catalog/*", }, @@ -694,6 +693,18 @@ func NewSystemBackend(core *Core, config *logical.BackendConfig) (logical.Backen HelpSynopsis: strings.TrimSpace(sysHelp["audited-headers"][0]), HelpDescription: strings.TrimSpace(sysHelp["audited-headers"][1]), }, + &framework.Path{ + Pattern: "plugin-catalog/$", + + Fields: map[string]*framework.FieldSchema{}, + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.ListOperation: b.handlePluginCatalogList, + }, + + HelpSynopsis: strings.TrimSpace(sysHelp["audited-headers-name"][0]), + HelpDescription: strings.TrimSpace(sysHelp["audited-headers-name"][1]), + }, &framework.Path{ Pattern: "plugin-catalog/(?P.+)", @@ -750,6 +761,16 @@ func (b *SystemBackend) invalidate(key string) { } } +func (b *SystemBackend) handlePluginCatalogList(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + plugins, err := b.Core.pluginCatalog.List() + if err != nil { + return nil, err + } + + resp := logical.ListResponse(plugins) + return resp, nil +} + func (b *SystemBackend) handlePluginCatalogUpdate(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { pluginName := d.Get("name").(string) if pluginName == "" { @@ -801,16 +822,12 @@ func (b *SystemBackend) handlePluginCatalogDelete(req *logical.Request, d *frame if pluginName == "" { return logical.ErrorResponse("missing plugin name"), nil } - plugin, err := b.Core.pluginCatalog.Get(pluginName) + err := b.Core.pluginCatalog.Delete(pluginName) if err != nil { return nil, err } - return &logical.Response{ - Data: map[string]interface{}{ - "plugin": plugin, - }, - }, nil + return nil, nil } // handleAuditedHeaderUpdate creates or overwrites a header entry diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go index 737f0c26b91d..264a43d44339 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugin_catalog.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "path/filepath" + "sort" "strings" "sync" @@ -39,6 +40,9 @@ func (c *Core) setupPluginCatalog() error { } func (c *PluginCatalog) Get(name string) (*pluginutil.PluginRunner, error) { + c.lock.RLock() + defer c.lock.RUnlock() + // Look for external plugins in the barrier out, err := c.catalogView.Get(name) if err != nil { @@ -68,6 +72,9 @@ func (c *PluginCatalog) Get(name string) (*pluginutil.PluginRunner, error) { } func (c *PluginCatalog) Set(name, command string, sha256 []byte) error { + c.lock.Lock() + defer c.lock.Unlock() + parts := strings.Split(command, " ") command = parts[0] args := parts[1:] @@ -111,3 +118,43 @@ func (c *PluginCatalog) Set(name, command string, sha256 []byte) error { } return nil } + +func (c *PluginCatalog) Delete(name string) error { + c.lock.Lock() + defer c.lock.Unlock() + + return c.catalogView.Delete(name) +} + +func (c *PluginCatalog) List() ([]string, error) { + c.lock.RLock() + defer c.lock.RUnlock() + + keys, err := logical.CollectKeys(c.catalogView) + if err != nil { + return nil, err + } + + builtinKeys := builtinplugins.BuiltinPlugins.Keys() + + mapKeys := make(map[string]bool) + + for _, plugin := range keys { + mapKeys[plugin] = true + } + + for _, plugin := range builtinKeys { + mapKeys[plugin] = true + } + + retList := make([]string, len(mapKeys)) + i := 0 + for k := range mapKeys { + retList[i] = k + i++ + } + // sort for consistent ordering of builtin pluings + sort.Strings(retList) + + return retList, nil +} diff --git a/vault/plugin_catalog_test.go b/vault/plugin_catalog_test.go new file mode 100644 index 000000000000..e78e7d963013 --- /dev/null +++ b/vault/plugin_catalog_test.go @@ -0,0 +1,166 @@ +package vault + +import ( + "fmt" + "io/ioutil" + "os" + "path/filepath" + "reflect" + "sort" + "testing" + + "github.com/hashicorp/vault/helper/builtinplugins" + "github.com/hashicorp/vault/helper/pluginutil" +) + +func TestPluginCatalog_CRUD(t *testing.T) { + core, _, _ := TestCoreUnsealed(t) + + sym, err := filepath.EvalSymlinks(os.TempDir()) + if err != nil { + t.Fatalf("error: %v", err) + } + core.pluginCatalog.directory = sym + core.pluginCatalog.vaultCommand = "vault" + core.pluginCatalog.vaultSHA256 = []byte{'1'} + + // Get builtin plugin + p, err := core.pluginCatalog.Get("mysql-database-plugin") + if err != nil { + t.Fatalf("unexpected error %v", err) + } + + expectedBuiltin := &pluginutil.PluginRunner{ + Name: "mysql-database-plugin", + Command: "vault", + Args: []string{"plugin-exec", "mysql-database-plugin"}, + Sha256: []byte{'1'}, + Builtin: true, + } + + if !reflect.DeepEqual(p, expectedBuiltin) { + t.Fatalf("expected did not match actual, got %#v\n expected %#v\n", p, expectedBuiltin) + } + + // Set a plugin, test overwriting a builtin plugin + file, err := ioutil.TempFile(os.TempDir(), "temp") + if err != nil { + t.Fatal(err) + } + defer file.Close() + + command := fmt.Sprintf("%s --test", filepath.Base(file.Name())) + err = core.pluginCatalog.Set("mysql-database-plugin", command, []byte{'1'}) + if err != nil { + t.Fatal(err) + } + + // Get the plugin + p, err = core.pluginCatalog.Get("mysql-database-plugin") + if err != nil { + t.Fatalf("unexpected error %v", err) + } + + expected := &pluginutil.PluginRunner{ + Name: "mysql-database-plugin", + Command: filepath.Join(sym, filepath.Base(file.Name())), + Args: []string{"--test"}, + Sha256: []byte{'1'}, + Builtin: false, + } + + if !reflect.DeepEqual(p, expected) { + t.Fatalf("expected did not match actual, got %#v\n expected %#v\n", p, expected) + } + + // Delete the plugin + err = core.pluginCatalog.Delete("mysql-database-plugin") + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + + // Get builtin plugin + p, err = core.pluginCatalog.Get("mysql-database-plugin") + if err != nil { + t.Fatalf("unexpected error %v", err) + } + + if !reflect.DeepEqual(p, expectedBuiltin) { + t.Fatalf("expected did not match actual, got %#v\n expected %#v\n", p, expectedBuiltin) + } + +} + +func TestPluginCatalog_List(t *testing.T) { + core, _, _ := TestCoreUnsealed(t) + + sym, err := filepath.EvalSymlinks(os.TempDir()) + if err != nil { + t.Fatalf("error: %v", err) + } + core.pluginCatalog.directory = sym + core.pluginCatalog.vaultCommand = "vault" + core.pluginCatalog.vaultSHA256 = []byte{'1'} + + // Get builtin plugins and sort them + builtinKeys := builtinplugins.BuiltinPlugins.Keys() + sort.Strings(builtinKeys) + + // List only builtin plugins + plugins, err := core.pluginCatalog.List() + if err != nil { + t.Fatalf("unexpected error %v", err) + } + + if len(plugins) != len(builtinKeys) { + t.Fatalf("unexpected length of plugin list, expected %d, got %d", len(builtinKeys), len(plugins)) + } + + for i, p := range builtinKeys { + if !reflect.DeepEqual(plugins[i], p) { + t.Fatalf("expected did not match actual, got %#v\n expected %#v\n", plugins[i], p) + } + } + + // Set a plugin, test overwriting a builtin plugin + file, err := ioutil.TempFile(os.TempDir(), "temp") + if err != nil { + t.Fatal(err) + } + defer file.Close() + + command := fmt.Sprintf("%s --test", filepath.Base(file.Name())) + err = core.pluginCatalog.Set("mysql-database-plugin", command, []byte{'1'}) + if err != nil { + t.Fatal(err) + } + + // Set another plugin + err = core.pluginCatalog.Set("aaaaaaa", command, []byte{'1'}) + if err != nil { + t.Fatal(err) + } + + // List the plugins + plugins, err = core.pluginCatalog.List() + if err != nil { + t.Fatalf("unexpected error %v", err) + } + + if len(plugins) != len(builtinKeys)+1 { + t.Fatalf("unexpected length of plugin list, expected %d, got %d", len(builtinKeys)+1, len(plugins)) + } + + // verify the first plugin is the one we just created. + if !reflect.DeepEqual(plugins[0], "aaaaaaa") { + t.Fatalf("expected did not match actual, got %#v\n expected %#v\n", plugins[0], "aaaaaaa") + } + + // verify the builtin pluings are correct + for i, p := range builtinKeys { + if !reflect.DeepEqual(plugins[i+1], p) { + t.Fatalf("expected did not match actual, got %#v\n expected %#v\n", plugins[i+1], p) + } + } + +} From 0e08279131eb8e62683b2fd7d1380785bb7eebd5 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 12 Apr 2017 10:01:36 -0700 Subject: [PATCH 069/152] Add path help and comments for plugin-catalog --- vault/logical_system.go | 25 +++++++++++++++++++++---- vault/plugin_catalog.go | 15 +++++++++++++++ 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/vault/logical_system.go b/vault/logical_system.go index fadae02bf616..2aff5d0e1b73 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -702,8 +702,8 @@ func NewSystemBackend(core *Core, config *logical.BackendConfig) (logical.Backen logical.ListOperation: b.handlePluginCatalogList, }, - HelpSynopsis: strings.TrimSpace(sysHelp["audited-headers-name"][0]), - HelpDescription: strings.TrimSpace(sysHelp["audited-headers-name"][1]), + HelpSynopsis: strings.TrimSpace(sysHelp["plugin-catalog"][0]), + HelpDescription: strings.TrimSpace(sysHelp["plugin-catalog"][1]), }, &framework.Path{ Pattern: "plugin-catalog/(?P.+)", @@ -726,8 +726,8 @@ func NewSystemBackend(core *Core, config *logical.BackendConfig) (logical.Backen logical.ReadOperation: b.handlePluginCatalogRead, }, - HelpSynopsis: strings.TrimSpace(sysHelp["audited-headers-name"][0]), - HelpDescription: strings.TrimSpace(sysHelp["audited-headers-name"][1]), + HelpSynopsis: strings.TrimSpace(sysHelp["plugin-catalog"][0]), + HelpDescription: strings.TrimSpace(sysHelp["plugin-catalog"][1]), }, }, } @@ -2506,4 +2506,21 @@ This path responds to the following HTTP methods. "Lists the headers configured to be audited.", `Returns a list of headers that have been configured to be audited.`, }, + "plugin-catalog": { + `Configures the plugins known to vault`, + ` +This path responds to the following HTTP methods. + GET / + Returns a list of names of configured plugins. + + GET / + Retrieve the metadata for the named plugin. + + PUT / + Add or update plugin. + + DELETE / + Delete the plugin with the given name. + `, + }, } diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go index 264a43d44339..b89224780c40 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugin_catalog.go @@ -19,6 +19,9 @@ var ( pluginCatalogPrefix = "plugin-catalog/" ) +// PluginCatalog keeps a record of plugins known to vault. External plugins need +// to be registered to the catalog before they can be used in backends. Builtin +// plugins are automatically detected and included in the catalog. type PluginCatalog struct { catalogView *BarrierView directory string @@ -39,6 +42,9 @@ func (c *Core) setupPluginCatalog() error { return nil } +// Get retrieves a plugin with the specified name from the catalog. It first +// looks for external plugins with this name and then looks for builtin plugins. +// It returns a PluginRunner or an error if no plugin was found. func (c *PluginCatalog) Get(name string) (*pluginutil.PluginRunner, error) { c.lock.RLock() defer c.lock.RUnlock() @@ -71,6 +77,8 @@ func (c *PluginCatalog) Get(name string) (*pluginutil.PluginRunner, error) { }, nil } +// Set registers a new external plugin with the catalog, or updates an existing +// external plugin. It takes the name, command and SHA256 of the plugin. func (c *PluginCatalog) Set(name, command string, sha256 []byte) error { c.lock.Lock() defer c.lock.Unlock() @@ -119,6 +127,8 @@ func (c *PluginCatalog) Set(name, command string, sha256 []byte) error { return nil } +// Delete is used to remove an external plugin from the catalog. Builtin plugins +// can not be deleted. func (c *PluginCatalog) Delete(name string) error { c.lock.Lock() defer c.lock.Unlock() @@ -126,17 +136,22 @@ func (c *PluginCatalog) Delete(name string) error { return c.catalogView.Delete(name) } +// List returns a list of all the known plugin names. If an external and builtin +// plugin share the same name, only one instance of the name will be returned. func (c *PluginCatalog) List() ([]string, error) { c.lock.RLock() defer c.lock.RUnlock() + // Collect keys for external plugins in the barrier. keys, err := logical.CollectKeys(c.catalogView) if err != nil { return nil, err } + // Get the keys for builtin plugins builtinKeys := builtinplugins.BuiltinPlugins.Keys() + // Use a map to unique the two lists mapKeys := make(map[string]bool) for _, plugin := range keys { From cb844b5113faa2dcf52bc75635988920137ca789 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 12 Apr 2017 10:39:18 -0700 Subject: [PATCH 070/152] Add test for logical_system plugin-catalog handling --- vault/logical_system_test.go | 93 ++++++++++++++++++++++++++++++++++++ 1 file changed, 93 insertions(+) diff --git a/vault/logical_system_test.go b/vault/logical_system_test.go index 15f60a50c25e..3c808677f9ed 100644 --- a/vault/logical_system_test.go +++ b/vault/logical_system_test.go @@ -2,6 +2,11 @@ package vault import ( "crypto/sha256" + "encoding/hex" + "fmt" + "io/ioutil" + "os" + "path/filepath" "reflect" "strings" "testing" @@ -9,6 +14,8 @@ import ( "github.com/fatih/structs" "github.com/hashicorp/vault/audit" + "github.com/hashicorp/vault/helper/builtinplugins" + "github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/helper/salt" "github.com/hashicorp/vault/logical" ) @@ -1076,3 +1083,89 @@ func testCoreSystemBackend(t *testing.T) (*Core, logical.Backend, string) { } return c, b, root } + +func TestSystemBackend_PluginCatalog_CRUD(t *testing.T) { + c, b, _ := testCoreSystemBackend(t) + // Bootstrap the pluginCatalog + sym, err := filepath.EvalSymlinks(os.TempDir()) + if err != nil { + t.Fatalf("error: %v", err) + } + c.pluginCatalog.directory = sym + c.pluginCatalog.vaultCommand = "vault" + c.pluginCatalog.vaultSHA256 = []byte{'1'} + + req := logical.TestRequest(t, logical.ListOperation, "plugin-catalog/") + resp, err := b.HandleRequest(req) + if err != nil { + t.Fatalf("err: %v", err) + } + + if len(resp.Data["keys"].([]string)) != len(builtinplugins.BuiltinPlugins.Keys()) { + t.Fatalf("Wrong number of plugins, got %d, expected %d", len(resp.Data["keys"].([]string)), len(builtinplugins.BuiltinPlugins.Keys())) + } + + req = logical.TestRequest(t, logical.ReadOperation, "plugin-catalog/mysql-database-plugin") + resp, err = b.HandleRequest(req) + if err != nil { + t.Fatalf("err: %v", err) + } + + expectedBuiltin := &pluginutil.PluginRunner{ + Name: "mysql-database-plugin", + Command: "vault", + Args: []string{"plugin-exec", "mysql-database-plugin"}, + Sha256: []byte{'1'}, + Builtin: true, + } + + if !reflect.DeepEqual(resp.Data["plugin"].(*pluginutil.PluginRunner), expectedBuiltin) { + t.Fatalf("expected did not match actual, got %#v\n expected %#v\n", resp.Data["plugin"].(*pluginutil.PluginRunner), expectedBuiltin) + } + + // Set a plugin + file, err := ioutil.TempFile(os.TempDir(), "temp") + if err != nil { + t.Fatal(err) + } + defer file.Close() + + command := fmt.Sprintf("%s --test", filepath.Base(file.Name())) + req = logical.TestRequest(t, logical.UpdateOperation, "plugin-catalog/test-plugin") + req.Data["sha_256"] = hex.EncodeToString([]byte{'1'}) + req.Data["command"] = command + resp, err = b.HandleRequest(req) + if err != nil { + t.Fatalf("err: %v", err) + } + + req = logical.TestRequest(t, logical.ReadOperation, "plugin-catalog/test-plugin") + resp, err = b.HandleRequest(req) + if err != nil { + t.Fatalf("err: %v", err) + } + + expected := &pluginutil.PluginRunner{ + Name: "test-plugin", + Command: filepath.Join(sym, filepath.Base(file.Name())), + Args: []string{"--test"}, + Sha256: []byte{'1'}, + Builtin: false, + } + if !reflect.DeepEqual(resp.Data["plugin"].(*pluginutil.PluginRunner), expected) { + t.Fatalf("expected did not match actual, got %#v\n expected %#v\n", resp.Data["plugin"].(*pluginutil.PluginRunner), expected) + } + + // Delete plugin + req = logical.TestRequest(t, logical.DeleteOperation, "plugin-catalog/test-plugin") + resp, err = b.HandleRequest(req) + if err != nil { + t.Fatalf("err: %v", err) + } + + req = logical.TestRequest(t, logical.ReadOperation, "plugin-catalog/test-plugin") + resp, err = b.HandleRequest(req) + if err == nil { + t.Fatalf("expected error, plugin not deleted correctly") + } +} From 1bc0243113be11266b2ea8221295c8339605939a Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 12 Apr 2017 14:22:52 -0700 Subject: [PATCH 071/152] Fix RootPaths test --- vault/logical_system_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/vault/logical_system_test.go b/vault/logical_system_test.go index 3c808677f9ed..c608b3b8670b 100644 --- a/vault/logical_system_test.go +++ b/vault/logical_system_test.go @@ -32,6 +32,7 @@ func TestSystemBackend_RootPaths(t *testing.T) { "replication/reindex", "rotate", "config/auditing/*", + "plugin-catalog/*", } b := testSystemBackend(t) From c9dc7b800b2aaf5f850a7796f137bed8c2d7e7d5 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 12 Apr 2017 14:23:15 -0700 Subject: [PATCH 072/152] vendor go-plugin --- vendor/github.com/hashicorp/go-plugin/LICENSE | 353 ++++++++++ .../github.com/hashicorp/go-plugin/README.md | 161 +++++ .../github.com/hashicorp/go-plugin/client.go | 666 ++++++++++++++++++ .../hashicorp/go-plugin/discover.go | 28 + .../github.com/hashicorp/go-plugin/error.go | 24 + .../hashicorp/go-plugin/mux_broker.go | 204 ++++++ .../github.com/hashicorp/go-plugin/plugin.go | 25 + .../github.com/hashicorp/go-plugin/process.go | 24 + .../hashicorp/go-plugin/process_posix.go | 19 + .../hashicorp/go-plugin/process_windows.go | 29 + .../hashicorp/go-plugin/rpc_client.go | 123 ++++ .../hashicorp/go-plugin/rpc_server.go | 185 +++++ .../github.com/hashicorp/go-plugin/server.go | 235 ++++++ .../hashicorp/go-plugin/server_mux.go | 31 + .../github.com/hashicorp/go-plugin/stream.go | 18 + .../github.com/hashicorp/go-plugin/testing.go | 76 ++ vendor/vendor.json | 6 + 17 files changed, 2207 insertions(+) create mode 100644 vendor/github.com/hashicorp/go-plugin/LICENSE create mode 100644 vendor/github.com/hashicorp/go-plugin/README.md create mode 100644 vendor/github.com/hashicorp/go-plugin/client.go create mode 100644 vendor/github.com/hashicorp/go-plugin/discover.go create mode 100644 vendor/github.com/hashicorp/go-plugin/error.go create mode 100644 vendor/github.com/hashicorp/go-plugin/mux_broker.go create mode 100644 vendor/github.com/hashicorp/go-plugin/plugin.go create mode 100644 vendor/github.com/hashicorp/go-plugin/process.go create mode 100644 vendor/github.com/hashicorp/go-plugin/process_posix.go create mode 100644 vendor/github.com/hashicorp/go-plugin/process_windows.go create mode 100644 vendor/github.com/hashicorp/go-plugin/rpc_client.go create mode 100644 vendor/github.com/hashicorp/go-plugin/rpc_server.go create mode 100644 vendor/github.com/hashicorp/go-plugin/server.go create mode 100644 vendor/github.com/hashicorp/go-plugin/server_mux.go create mode 100644 vendor/github.com/hashicorp/go-plugin/stream.go create mode 100644 vendor/github.com/hashicorp/go-plugin/testing.go diff --git a/vendor/github.com/hashicorp/go-plugin/LICENSE b/vendor/github.com/hashicorp/go-plugin/LICENSE new file mode 100644 index 000000000000..82b4de97c7e3 --- /dev/null +++ b/vendor/github.com/hashicorp/go-plugin/LICENSE @@ -0,0 +1,353 @@ +Mozilla Public License, version 2.0 + +1. Definitions + +1.1. “Contributor” + + means each individual or legal entity that creates, contributes to the + creation of, or owns Covered Software. + +1.2. “Contributor Version” + + means the combination of the Contributions of others (if any) used by a + Contributor and that particular Contributor’s Contribution. + +1.3. “Contribution” + + means Covered Software of a particular Contributor. + +1.4. “Covered Software” + + means Source Code Form to which the initial Contributor has attached the + notice in Exhibit A, the Executable Form of such Source Code Form, and + Modifications of such Source Code Form, in each case including portions + thereof. + +1.5. “Incompatible With Secondary Licenses” + means + + a. that the initial Contributor has attached the notice described in + Exhibit B to the Covered Software; or + + b. that the Covered Software was made available under the terms of version + 1.1 or earlier of the License, but not also under the terms of a + Secondary License. + +1.6. “Executable Form” + + means any form of the work other than Source Code Form. + +1.7. “Larger Work” + + means a work that combines Covered Software with other material, in a separate + file or files, that is not Covered Software. + +1.8. “License” + + means this document. + +1.9. “Licensable” + + means having the right to grant, to the maximum extent possible, whether at the + time of the initial grant or subsequently, any and all of the rights conveyed by + this License. + +1.10. “Modifications” + + means any of the following: + + a. any file in Source Code Form that results from an addition to, deletion + from, or modification of the contents of Covered Software; or + + b. any new file in Source Code Form that contains any Covered Software. + +1.11. “Patent Claims” of a Contributor + + means any patent claim(s), including without limitation, method, process, + and apparatus claims, in any patent Licensable by such Contributor that + would be infringed, but for the grant of the License, by the making, + using, selling, offering for sale, having made, import, or transfer of + either its Contributions or its Contributor Version. + +1.12. “Secondary License” + + means either the GNU General Public License, Version 2.0, the GNU Lesser + General Public License, Version 2.1, the GNU Affero General Public + License, Version 3.0, or any later versions of those licenses. + +1.13. “Source Code Form” + + means the form of the work preferred for making modifications. + +1.14. “You” (or “Your”) + + means an individual or a legal entity exercising rights under this + License. For legal entities, “You” includes any entity that controls, is + controlled by, or is under common control with You. For purposes of this + definition, “control” means (a) the power, direct or indirect, to cause + the direction or management of such entity, whether by contract or + otherwise, or (b) ownership of more than fifty percent (50%) of the + outstanding shares or beneficial ownership of such entity. + + +2. License Grants and Conditions + +2.1. Grants + + Each Contributor hereby grants You a world-wide, royalty-free, + non-exclusive license: + + a. under intellectual property rights (other than patent or trademark) + Licensable by such Contributor to use, reproduce, make available, + modify, display, perform, distribute, and otherwise exploit its + Contributions, either on an unmodified basis, with Modifications, or as + part of a Larger Work; and + + b. under Patent Claims of such Contributor to make, use, sell, offer for + sale, have made, import, and otherwise transfer either its Contributions + or its Contributor Version. + +2.2. Effective Date + + The licenses granted in Section 2.1 with respect to any Contribution become + effective for each Contribution on the date the Contributor first distributes + such Contribution. + +2.3. Limitations on Grant Scope + + The licenses granted in this Section 2 are the only rights granted under this + License. No additional rights or licenses will be implied from the distribution + or licensing of Covered Software under this License. Notwithstanding Section + 2.1(b) above, no patent license is granted by a Contributor: + + a. for any code that a Contributor has removed from Covered Software; or + + b. for infringements caused by: (i) Your and any other third party’s + modifications of Covered Software, or (ii) the combination of its + Contributions with other software (except as part of its Contributor + Version); or + + c. under Patent Claims infringed by Covered Software in the absence of its + Contributions. + + This License does not grant any rights in the trademarks, service marks, or + logos of any Contributor (except as may be necessary to comply with the + notice requirements in Section 3.4). + +2.4. Subsequent Licenses + + No Contributor makes additional grants as a result of Your choice to + distribute the Covered Software under a subsequent version of this License + (see Section 10.2) or under the terms of a Secondary License (if permitted + under the terms of Section 3.3). + +2.5. Representation + + Each Contributor represents that the Contributor believes its Contributions + are its original creation(s) or it has sufficient rights to grant the + rights to its Contributions conveyed by this License. + +2.6. Fair Use + + This License is not intended to limit any rights You have under applicable + copyright doctrines of fair use, fair dealing, or other equivalents. + +2.7. Conditions + + Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted in + Section 2.1. + + +3. Responsibilities + +3.1. Distribution of Source Form + + All distribution of Covered Software in Source Code Form, including any + Modifications that You create or to which You contribute, must be under the + terms of this License. You must inform recipients that the Source Code Form + of the Covered Software is governed by the terms of this License, and how + they can obtain a copy of this License. You may not attempt to alter or + restrict the recipients’ rights in the Source Code Form. + +3.2. Distribution of Executable Form + + If You distribute Covered Software in Executable Form then: + + a. such Covered Software must also be made available in Source Code Form, + as described in Section 3.1, and You must inform recipients of the + Executable Form how they can obtain a copy of such Source Code Form by + reasonable means in a timely manner, at a charge no more than the cost + of distribution to the recipient; and + + b. You may distribute such Executable Form under the terms of this License, + or sublicense it under different terms, provided that the license for + the Executable Form does not attempt to limit or alter the recipients’ + rights in the Source Code Form under this License. + +3.3. Distribution of a Larger Work + + You may create and distribute a Larger Work under terms of Your choice, + provided that You also comply with the requirements of this License for the + Covered Software. If the Larger Work is a combination of Covered Software + with a work governed by one or more Secondary Licenses, and the Covered + Software is not Incompatible With Secondary Licenses, this License permits + You to additionally distribute such Covered Software under the terms of + such Secondary License(s), so that the recipient of the Larger Work may, at + their option, further distribute the Covered Software under the terms of + either this License or such Secondary License(s). + +3.4. Notices + + You may not remove or alter the substance of any license notices (including + copyright notices, patent notices, disclaimers of warranty, or limitations + of liability) contained within the Source Code Form of the Covered + Software, except that You may alter any license notices to the extent + required to remedy known factual inaccuracies. + +3.5. Application of Additional Terms + + You may choose to offer, and to charge a fee for, warranty, support, + indemnity or liability obligations to one or more recipients of Covered + Software. However, You may do so only on Your own behalf, and not on behalf + of any Contributor. You must make it absolutely clear that any such + warranty, support, indemnity, or liability obligation is offered by You + alone, and You hereby agree to indemnify every Contributor for any + liability incurred by such Contributor as a result of warranty, support, + indemnity or liability terms You offer. You may include additional + disclaimers of warranty and limitations of liability specific to any + jurisdiction. + +4. Inability to Comply Due to Statute or Regulation + + If it is impossible for You to comply with any of the terms of this License + with respect to some or all of the Covered Software due to statute, judicial + order, or regulation then You must: (a) comply with the terms of this License + to the maximum extent possible; and (b) describe the limitations and the code + they affect. Such description must be placed in a text file included with all + distributions of the Covered Software under this License. Except to the + extent prohibited by statute or regulation, such description must be + sufficiently detailed for a recipient of ordinary skill to be able to + understand it. + +5. Termination + +5.1. The rights granted under this License will terminate automatically if You + fail to comply with any of its terms. However, if You become compliant, + then the rights granted under this License from a particular Contributor + are reinstated (a) provisionally, unless and until such Contributor + explicitly and finally terminates Your grants, and (b) on an ongoing basis, + if such Contributor fails to notify You of the non-compliance by some + reasonable means prior to 60 days after You have come back into compliance. + Moreover, Your grants from a particular Contributor are reinstated on an + ongoing basis if such Contributor notifies You of the non-compliance by + some reasonable means, this is the first time You have received notice of + non-compliance with this License from such Contributor, and You become + compliant prior to 30 days after Your receipt of the notice. + +5.2. If You initiate litigation against any entity by asserting a patent + infringement claim (excluding declaratory judgment actions, counter-claims, + and cross-claims) alleging that a Contributor Version directly or + indirectly infringes any patent, then the rights granted to You by any and + all Contributors for the Covered Software under Section 2.1 of this License + shall terminate. + +5.3. In the event of termination under Sections 5.1 or 5.2 above, all end user + license agreements (excluding distributors and resellers) which have been + validly granted by You or Your distributors under this License prior to + termination shall survive termination. + +6. Disclaimer of Warranty + + Covered Software is provided under this License on an “as is” basis, without + warranty of any kind, either expressed, implied, or statutory, including, + without limitation, warranties that the Covered Software is free of defects, + merchantable, fit for a particular purpose or non-infringing. The entire + risk as to the quality and performance of the Covered Software is with You. + Should any Covered Software prove defective in any respect, You (not any + Contributor) assume the cost of any necessary servicing, repair, or + correction. This disclaimer of warranty constitutes an essential part of this + License. No use of any Covered Software is authorized under this License + except under this disclaimer. + +7. Limitation of Liability + + Under no circumstances and under no legal theory, whether tort (including + negligence), contract, or otherwise, shall any Contributor, or anyone who + distributes Covered Software as permitted above, be liable to You for any + direct, indirect, special, incidental, or consequential damages of any + character including, without limitation, damages for lost profits, loss of + goodwill, work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses, even if such party shall have been + informed of the possibility of such damages. This limitation of liability + shall not apply to liability for death or personal injury resulting from such + party’s negligence to the extent applicable law prohibits such limitation. + Some jurisdictions do not allow the exclusion or limitation of incidental or + consequential damages, so this exclusion and limitation may not apply to You. + +8. Litigation + + Any litigation relating to this License may be brought only in the courts of + a jurisdiction where the defendant maintains its principal place of business + and such litigation shall be governed by laws of that jurisdiction, without + reference to its conflict-of-law provisions. Nothing in this Section shall + prevent a party’s ability to bring cross-claims or counter-claims. + +9. Miscellaneous + + This License represents the complete agreement concerning the subject matter + hereof. If any provision of this License is held to be unenforceable, such + provision shall be reformed only to the extent necessary to make it + enforceable. Any law or regulation which provides that the language of a + contract shall be construed against the drafter shall not be used to construe + this License against a Contributor. + + +10. Versions of the License + +10.1. New Versions + + Mozilla Foundation is the license steward. Except as provided in Section + 10.3, no one other than the license steward has the right to modify or + publish new versions of this License. Each version will be given a + distinguishing version number. + +10.2. Effect of New Versions + + You may distribute the Covered Software under the terms of the version of + the License under which You originally received the Covered Software, or + under the terms of any subsequent version published by the license + steward. + +10.3. Modified Versions + + If you create software not governed by this License, and you want to + create a new license for such software, you may create and use a modified + version of this License if you rename the license and remove any + references to the name of the license steward (except to note that such + modified license differs from this License). + +10.4. Distributing Source Code Form that is Incompatible With Secondary Licenses + If You choose to distribute Source Code Form that is Incompatible With + Secondary Licenses under the terms of this version of the License, the + notice described in Exhibit B of this License must be attached. + +Exhibit A - Source Code Form License Notice + + This Source Code Form is subject to the + terms of the Mozilla Public License, v. + 2.0. If a copy of the MPL was not + distributed with this file, You can + obtain one at + http://mozilla.org/MPL/2.0/. + +If it is not possible or desirable to put the notice in a particular file, then +You may include the notice in a location (such as a LICENSE file in a relevant +directory) where a recipient would be likely to look for such a notice. + +You may add additional accurate notices of copyright ownership. + +Exhibit B - “Incompatible With Secondary Licenses” Notice + + This Source Code Form is “Incompatible + With Secondary Licenses”, as defined by + the Mozilla Public License, v. 2.0. diff --git a/vendor/github.com/hashicorp/go-plugin/README.md b/vendor/github.com/hashicorp/go-plugin/README.md new file mode 100644 index 000000000000..2058cfb68d19 --- /dev/null +++ b/vendor/github.com/hashicorp/go-plugin/README.md @@ -0,0 +1,161 @@ +# Go Plugin System over RPC + +`go-plugin` is a Go (golang) plugin system over RPC. It is the plugin system +that has been in use by HashiCorp tooling for over 3 years. While initially +created for [Packer](https://www.packer.io), it has since been used by +[Terraform](https://www.terraform.io) and [Otto](https://www.ottoproject.io), +with plans to also use it for [Nomad](https://www.nomadproject.io) and +[Vault](https://www.vaultproject.io). + +While the plugin system is over RPC, it is currently only designed to work +over a local [reliable] network. Plugins over a real network are not supported +and will lead to unexpected behavior. + +This plugin system has been used on millions of machines across many different +projects and has proven to be battle hardened and ready for production use. + +## Features + +The HashiCorp plugin system supports a number of features: + +**Plugins are Go interface implementations.** This makes writing and consuming +plugins feel very natural. To a plugin author: you just implement an +interface as if it were going to run in the same process. For a plugin user: +you just use and call functions on an interface as if it were in the same +process. This plugin system handles the communication in between. + +**Complex arguments and return values are supported.** This library +provides APIs for handling complex arguments and return values such +as interfaces, `io.Reader/Writer`, etc. We do this by giving you a library +(`MuxBroker`) for creating new connections between the client/server to +serve additional interfaces or transfer raw data. + +**Bidirectional communication.** Because the plugin system supports +complex arguments, the host process can send it interface implementations +and the plugin can call back into the host process. + +**Built-in Logging.** Any plugins that use the `log` standard library +will have log data automatically sent to the host process. The host +process will mirror this output prefixed with the path to the plugin +binary. This makes debugging with plugins simple. + +**Protocol Versioning.** A very basic "protocol version" is supported that +can be incremented to invalidate any previous plugins. This is useful when +interface signatures are changing, protocol level changes are necessary, +etc. When a protocol version is incompatible, a human friendly error +message is shown to the end user. + +**Stdout/Stderr Syncing.** While plugins are subprocesses, they can continue +to use stdout/stderr as usual and the output will get mirrored back to +the host process. The host process can control what `io.Writer` these +streams go to to prevent this from happening. + +**TTY Preservation.** Plugin subprocesses are connected to the identical +stdin file descriptor as the host process, allowing software that requires +a TTY to work. For example, a plugin can execute `ssh` and even though there +are multiple subprocesses and RPC happening, it will look and act perfectly +to the end user. + +**Host upgrade while a plugin is running.** Plugins can be "reattached" +so that the host process can be upgraded while the plugin is still running. +This requires the host/plugin to know this is possible and daemonize +properly. `NewClient` takes a `ReattachConfig` to determine if and how to +reattach. + +## Architecture + +The HashiCorp plugin system works by launching subprocesses and communicating +over RPC (using standard `net/rpc`). A single connection is made between +any plugin and the host process, and we use a +[connection multiplexing](https://github.com/hashicorp/yamux) +library to multiplex any other connections on top. + +This architecture has a number of benefits: + + * Plugins can't crash your host process: A panic in a plugin doesn't + panic the plugin user. + + * Plugins are very easy to write: just write a Go application and `go build`. + Theoretically you could also use another language as long as it can + communicate the Go `net/rpc` protocol but this hasn't yet been tried. + + * Plugins are very easy to install: just put the binary in a location where + the host will find it (depends on the host but this library also provides + helpers), and the plugin host handles the rest. + + * Plugins can be relatively secure: The plugin only has access to the + interfaces and args given to it, not to the entire memory space of the + process. More security features are planned (see the coming soon section + below). + +## Usage + +To use the plugin system, you must take the following steps. These are +high-level steps that must be done. Examples are available in the +`examples/` directory. + + 1. Choose the interface(s) you want to expose for plugins. + + 2. For each interface, implement an implementation of that interface + that communicates over an `*rpc.Client` (from the standard `net/rpc` + package) for every function call. Likewise, implement the RPC server + struct this communicates to which is then communicating to a real, + concrete implementation. + + 3. Create a `Plugin` implementation that knows how to create the RPC + client/server for a given plugin type. + + 4. Plugin authors call `plugin.Serve` to serve a plugin from the + `main` function. + + 5. Plugin users use `plugin.Client` to launch a subprocess and request + an interface implementation over RPC. + +That's it! In practice, step 2 is the most tedious and time consuming step. +Even so, it isn't very difficult and you can see examples in the `examples/` +directory as well as throughout our various open source projects. + +For complete API documentation, see [GoDoc](https://godoc.org/github.com/hashicorp/go-plugin). + +## Roadmap + +Our plugin system is constantly evolving. As we use the plugin system for +new projects or for new features in existing projects, we constantly find +improvements we can make. + +At this point in time, the roadmap for the plugin system is: + +**Cryptographically Secure Plugins.** We'll implement signing plugins +and loading signed plugins in order to allow Vault to make use of multi-process +in a secure way. + +**Semantic Versioning.** Plugins will be able to implement a semantic version. +This plugin system will give host processes a system for constraining +versions. This is in addition to the protocol versioning already present +which is more for larger underlying changes. + +**Plugin fetching.** We will integrate with [go-getter](https://github.com/hashicorp/go-getter) +to support automatic download + install of plugins. Paired with cryptographically +secure plugins (above), we can make this a safe operation for an amazing +user experience. + +## What About Shared Libraries? + +When we started using plugins (late 2012, early 2013), plugins over RPC +were the only option since Go didn't support dynamic library loading. Today, +Go still doesn't support dynamic library loading, but they do intend to. +Since 2012, our plugin system has stabilized from millions of users using it, +and has many benefits we've come to value greatly. + +For example, we intend to use this plugin system in +[Vault](https://www.vaultproject.io), and dynamic library loading will +simply never be acceptable in Vault for security reasons. That is an extreme +example, but we believe our library system has more upsides than downsides +over dynamic library loading and since we've had it built and tested for years, +we'll likely continue to use it. + +Shared libraries have one major advantage over our system which is much +higher performance. In real world scenarios across our various tools, +we've never required any more performance out of our plugin system and it +has seen very high throughput, so this isn't a concern for us at the moment. + diff --git a/vendor/github.com/hashicorp/go-plugin/client.go b/vendor/github.com/hashicorp/go-plugin/client.go new file mode 100644 index 000000000000..b69d41b28d68 --- /dev/null +++ b/vendor/github.com/hashicorp/go-plugin/client.go @@ -0,0 +1,666 @@ +package plugin + +import ( + "bufio" + "crypto/subtle" + "crypto/tls" + "errors" + "fmt" + "hash" + "io" + "io/ioutil" + "log" + "net" + "os" + "os/exec" + "path/filepath" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + "unicode" +) + +// If this is 1, then we've called CleanupClients. This can be used +// by plugin RPC implementations to change error behavior since you +// can expected network connection errors at this point. This should be +// read by using sync/atomic. +var Killed uint32 = 0 + +// This is a slice of the "managed" clients which are cleaned up when +// calling Cleanup +var managedClients = make([]*Client, 0, 5) +var managedClientsLock sync.Mutex + +// Error types +var ( + // ErrProcessNotFound is returned when a client is instantiated to + // reattach to an existing process and it isn't found. + ErrProcessNotFound = errors.New("Reattachment process not found") + + // ErrChecksumsDoNotMatch is returned when binary's checksum doesn't match + // the one provided in the SecureConfig. + ErrChecksumsDoNotMatch = errors.New("checksums did not match") + + // ErrSecureNoChecksum is returned when an empty checksum is provided to the + // SecureConfig. + ErrSecureConfigNoChecksum = errors.New("no checksum provided") + + // ErrSecureNoHash is returned when a nil Hash object is provided to the + // SecureConfig. + ErrSecureConfigNoHash = errors.New("no hash implementation provided") + + // ErrSecureConfigAndReattach is returned when both Reattach and + // SecureConfig are set. + ErrSecureConfigAndReattach = errors.New("only one of Reattach or SecureConfig can be set") +) + +// Client handles the lifecycle of a plugin application. It launches +// plugins, connects to them, dispenses interface implementations, and handles +// killing the process. +// +// Plugin hosts should use one Client for each plugin executable. To +// dispense a plugin type, use the `Client.Client` function, and then +// cal `Dispense`. This awkward API is mostly historical but is used to split +// the client that deals with subprocess management and the client that +// does RPC management. +// +// See NewClient and ClientConfig for using a Client. +type Client struct { + config *ClientConfig + exited bool + doneLogging chan struct{} + l sync.Mutex + address net.Addr + process *os.Process + client *RPCClient +} + +// ClientConfig is the configuration used to initialize a new +// plugin client. After being used to initialize a plugin client, +// that configuration must not be modified again. +type ClientConfig struct { + // HandshakeConfig is the configuration that must match servers. + HandshakeConfig + + // Plugins are the plugins that can be consumed. + Plugins map[string]Plugin + + // One of the following must be set, but not both. + // + // Cmd is the unstarted subprocess for starting the plugin. If this is + // set, then the Client starts the plugin process on its own and connects + // to it. + // + // Reattach is configuration for reattaching to an existing plugin process + // that is already running. This isn't common. + Cmd *exec.Cmd + Reattach *ReattachConfig + + // SecureConfig is configuration for verifying the integrity of the + // executable. It can not be used with Reattach. + SecureConfig *SecureConfig + + // TLSConfig is used to enable TLS on the RPC client. + TLSConfig *tls.Config + + // Managed represents if the client should be managed by the + // plugin package or not. If true, then by calling CleanupClients, + // it will automatically be cleaned up. Otherwise, the client + // user is fully responsible for making sure to Kill all plugin + // clients. By default the client is _not_ managed. + Managed bool + + // The minimum and maximum port to use for communicating with + // the subprocess. If not set, this defaults to 10,000 and 25,000 + // respectively. + MinPort, MaxPort uint + + // StartTimeout is the timeout to wait for the plugin to say it + // has started successfully. + StartTimeout time.Duration + + // If non-nil, then the stderr of the client will be written to here + // (as well as the log). This is the original os.Stderr of the subprocess. + // This isn't the output of synced stderr. + Stderr io.Writer + + // SyncStdout, SyncStderr can be set to override the + // respective os.Std* values in the plugin. Care should be taken to + // avoid races here. If these are nil, then this will automatically be + // hooked up to os.Stdin, Stdout, and Stderr, respectively. + // + // If the default values (nil) are used, then this package will not + // sync any of these streams. + SyncStdout io.Writer + SyncStderr io.Writer +} + +// ReattachConfig is used to configure a client to reattach to an +// already-running plugin process. You can retrieve this information by +// calling ReattachConfig on Client. +type ReattachConfig struct { + Addr net.Addr + Pid int +} + +// SecureConfig is used to configure a client to verify the integrity of an +// executable before running. It does this by verifying the checksum is +// expected. Hash is used to specify the hashing method to use when checksumming +// the file. The configuration is verified by the client by calling the +// SecureConfig.Check() function. +// +// The host process should ensure the checksum was provided by a trusted and +// authoritative source. The binary should be installed in such a way that it +// can not be modified by an unauthorized user between the time of this check +// and the time of execution. +type SecureConfig struct { + Checksum []byte + Hash hash.Hash +} + +// Check takes the filepath to an executable and returns true if the checksum of +// the file matches the checksum provided in the SecureConfig. +func (s *SecureConfig) Check(filePath string) (bool, error) { + if len(s.Checksum) == 0 { + return false, ErrSecureConfigNoChecksum + } + + if s.Hash == nil { + return false, ErrSecureConfigNoHash + } + + file, err := os.Open(filePath) + if err != nil { + return false, err + } + defer file.Close() + + _, err = io.Copy(s.Hash, file) + if err != nil { + return false, err + } + + sum := s.Hash.Sum(nil) + + return subtle.ConstantTimeCompare(sum, s.Checksum) == 1, nil +} + +// This makes sure all the managed subprocesses are killed and properly +// logged. This should be called before the parent process running the +// plugins exits. +// +// This must only be called _once_. +func CleanupClients() { + // Set the killed to true so that we don't get unexpected panics + atomic.StoreUint32(&Killed, 1) + + // Kill all the managed clients in parallel and use a WaitGroup + // to wait for them all to finish up. + var wg sync.WaitGroup + managedClientsLock.Lock() + for _, client := range managedClients { + wg.Add(1) + + go func(client *Client) { + client.Kill() + wg.Done() + }(client) + } + managedClientsLock.Unlock() + + log.Println("[DEBUG] plugin: waiting for all plugin processes to complete...") + wg.Wait() +} + +// Creates a new plugin client which manages the lifecycle of an external +// plugin and gets the address for the RPC connection. +// +// The client must be cleaned up at some point by calling Kill(). If +// the client is a managed client (created with NewManagedClient) you +// can just call CleanupClients at the end of your program and they will +// be properly cleaned. +func NewClient(config *ClientConfig) (c *Client) { + if config.MinPort == 0 && config.MaxPort == 0 { + config.MinPort = 10000 + config.MaxPort = 25000 + } + + if config.StartTimeout == 0 { + config.StartTimeout = 1 * time.Minute + } + + if config.Stderr == nil { + config.Stderr = ioutil.Discard + } + + if config.SyncStdout == nil { + config.SyncStdout = ioutil.Discard + } + if config.SyncStderr == nil { + config.SyncStderr = ioutil.Discard + } + + c = &Client{config: config} + if config.Managed { + managedClientsLock.Lock() + managedClients = append(managedClients, c) + managedClientsLock.Unlock() + } + + return +} + +// Client returns an RPC client for the plugin. +// +// Subsequent calls to this will return the same RPC client. +func (c *Client) Client() (*RPCClient, error) { + addr, err := c.Start() + if err != nil { + return nil, err + } + + c.l.Lock() + defer c.l.Unlock() + + if c.client != nil { + return c.client, nil + } + + // Connect to the client + conn, err := net.Dial(addr.Network(), addr.String()) + if err != nil { + return nil, err + } + if tcpConn, ok := conn.(*net.TCPConn); ok { + // Make sure to set keep alive so that the connection doesn't die + tcpConn.SetKeepAlive(true) + } + + if c.config.TLSConfig != nil { + conn = tls.Client(conn, c.config.TLSConfig) + } + + // Create the actual RPC client + c.client, err = NewRPCClient(conn, c.config.Plugins) + if err != nil { + conn.Close() + return nil, err + } + + // Begin the stream syncing so that stdin, out, err work properly + err = c.client.SyncStreams( + c.config.SyncStdout, + c.config.SyncStderr) + if err != nil { + c.client.Close() + c.client = nil + return nil, err + } + + return c.client, nil +} + +// Tells whether or not the underlying process has exited. +func (c *Client) Exited() bool { + c.l.Lock() + defer c.l.Unlock() + return c.exited +} + +// End the executing subprocess (if it is running) and perform any cleanup +// tasks necessary such as capturing any remaining logs and so on. +// +// This method blocks until the process successfully exits. +// +// This method can safely be called multiple times. +func (c *Client) Kill() { + // Grab a lock to read some private fields. + c.l.Lock() + process := c.process + addr := c.address + doneCh := c.doneLogging + c.l.Unlock() + + // If there is no process, we never started anything. Nothing to kill. + if process == nil { + return + } + + // We need to check for address here. It is possible that the plugin + // started (process != nil) but has no address (addr == nil) if the + // plugin failed at startup. If we do have an address, we need to close + // the plugin net connections. + graceful := false + if addr != nil { + // Close the client to cleanly exit the process. + client, err := c.Client() + if err == nil { + err = client.Close() + + // If there is no error, then we attempt to wait for a graceful + // exit. If there was an error, we assume that graceful cleanup + // won't happen and just force kill. + graceful = err == nil + if err != nil { + // If there was an error just log it. We're going to force + // kill in a moment anyways. + log.Printf( + "[WARN] plugin: error closing client during Kill: %s", err) + } + } + } + + // If we're attempting a graceful exit, then we wait for a short period + // of time to allow that to happen. To wait for this we just wait on the + // doneCh which would be closed if the process exits. + if graceful { + select { + case <-doneCh: + return + case <-time.After(250 * time.Millisecond): + } + } + + // If graceful exiting failed, just kill it + process.Kill() + + // Wait for the client to finish logging so we have a complete log + <-doneCh +} + +// Starts the underlying subprocess, communicating with it to negotiate +// a port for RPC connections, and returning the address to connect via RPC. +// +// This method is safe to call multiple times. Subsequent calls have no effect. +// Once a client has been started once, it cannot be started again, even if +// it was killed. +func (c *Client) Start() (addr net.Addr, err error) { + c.l.Lock() + defer c.l.Unlock() + + if c.address != nil { + return c.address, nil + } + + // If one of cmd or reattach isn't set, then it is an error. We wrap + // this in a {} for scoping reasons, and hopeful that the escape + // analysis will pop the stock here. + { + cmdSet := c.config.Cmd != nil + attachSet := c.config.Reattach != nil + secureSet := c.config.SecureConfig != nil + if cmdSet == attachSet { + return nil, fmt.Errorf("Only one of Cmd or Reattach must be set") + } + + if secureSet && attachSet { + return nil, ErrSecureConfigAndReattach + } + } + + // Create the logging channel for when we kill + c.doneLogging = make(chan struct{}) + + if c.config.Reattach != nil { + // Verify the process still exists. If not, then it is an error + p, err := os.FindProcess(c.config.Reattach.Pid) + if err != nil { + return nil, err + } + + // Attempt to connect to the addr since on Unix systems FindProcess + // doesn't actually return an error if it can't find the process. + conn, err := net.Dial( + c.config.Reattach.Addr.Network(), + c.config.Reattach.Addr.String()) + if err != nil { + p.Kill() + return nil, ErrProcessNotFound + } + conn.Close() + + // Goroutine to mark exit status + go func(pid int) { + // Wait for the process to die + pidWait(pid) + + // Log so we can see it + log.Printf("[DEBUG] plugin: reattached plugin process exited\n") + + // Mark it + c.l.Lock() + defer c.l.Unlock() + c.exited = true + + // Close the logging channel since that doesn't work on reattach + close(c.doneLogging) + }(p.Pid) + + // Set the address and process + c.address = c.config.Reattach.Addr + c.process = p + + return c.address, nil + } + + env := []string{ + fmt.Sprintf("%s=%s", c.config.MagicCookieKey, c.config.MagicCookieValue), + fmt.Sprintf("PLUGIN_MIN_PORT=%d", c.config.MinPort), + fmt.Sprintf("PLUGIN_MAX_PORT=%d", c.config.MaxPort), + } + + stdout_r, stdout_w := io.Pipe() + stderr_r, stderr_w := io.Pipe() + + cmd := c.config.Cmd + cmd.Env = append(cmd.Env, os.Environ()...) + cmd.Env = append(cmd.Env, env...) + cmd.Stdin = os.Stdin + cmd.Stderr = stderr_w + cmd.Stdout = stdout_w + + if c.config.SecureConfig != nil { + if ok, err := c.config.SecureConfig.Check(cmd.Path); err != nil { + return nil, fmt.Errorf("error verifying checksum: %s", err) + } else if !ok { + return nil, ErrChecksumsDoNotMatch + } + } + + log.Printf("[DEBUG] plugin: starting plugin: %s %#v", cmd.Path, cmd.Args) + err = cmd.Start() + if err != nil { + return + } + + // Set the process + c.process = cmd.Process + + // Make sure the command is properly cleaned up if there is an error + defer func() { + r := recover() + + if err != nil || r != nil { + cmd.Process.Kill() + } + + if r != nil { + panic(r) + } + }() + + // Start goroutine to wait for process to exit + exitCh := make(chan struct{}) + go func() { + // Make sure we close the write end of our stderr/stdout so + // that the readers send EOF properly. + defer stderr_w.Close() + defer stdout_w.Close() + + // Wait for the command to end. + cmd.Wait() + + // Log and make sure to flush the logs write away + log.Printf("[DEBUG] plugin: %s: plugin process exited\n", cmd.Path) + os.Stderr.Sync() + + // Mark that we exited + close(exitCh) + + // Set that we exited, which takes a lock + c.l.Lock() + defer c.l.Unlock() + c.exited = true + }() + + // Start goroutine that logs the stderr + go c.logStderr(stderr_r) + + // Start a goroutine that is going to be reading the lines + // out of stdout + linesCh := make(chan []byte) + go func() { + defer close(linesCh) + + buf := bufio.NewReader(stdout_r) + for { + line, err := buf.ReadBytes('\n') + if line != nil { + linesCh <- line + } + + if err == io.EOF { + return + } + } + }() + + // Make sure after we exit we read the lines from stdout forever + // so they don't block since it is an io.Pipe + defer func() { + go func() { + for _ = range linesCh { + } + }() + }() + + // Some channels for the next step + timeout := time.After(c.config.StartTimeout) + + // Start looking for the address + log.Printf("[DEBUG] plugin: waiting for RPC address for: %s", cmd.Path) + select { + case <-timeout: + err = errors.New("timeout while waiting for plugin to start") + case <-exitCh: + err = errors.New("plugin exited before we could connect") + case lineBytes := <-linesCh: + // Trim the line and split by "|" in order to get the parts of + // the output. + line := strings.TrimSpace(string(lineBytes)) + parts := strings.SplitN(line, "|", 4) + if len(parts) < 4 { + err = fmt.Errorf( + "Unrecognized remote plugin message: %s\n\n"+ + "This usually means that the plugin is either invalid or simply\n"+ + "needs to be recompiled to support the latest protocol.", line) + return + } + + // Check the core protocol. Wrapped in a {} for scoping. + { + var coreProtocol int64 + coreProtocol, err = strconv.ParseInt(parts[0], 10, 0) + if err != nil { + err = fmt.Errorf("Error parsing core protocol version: %s", err) + return + } + + if int(coreProtocol) != CoreProtocolVersion { + err = fmt.Errorf("Incompatible core API version with plugin. "+ + "Plugin version: %s, Ours: %d\n\n"+ + "To fix this, the plugin usually only needs to be recompiled.\n"+ + "Please report this to the plugin author.", parts[0], CoreProtocolVersion) + return + } + } + + // Parse the protocol version + var protocol int64 + protocol, err = strconv.ParseInt(parts[1], 10, 0) + if err != nil { + err = fmt.Errorf("Error parsing protocol version: %s", err) + return + } + + // Test the API version + if uint(protocol) != c.config.ProtocolVersion { + err = fmt.Errorf("Incompatible API version with plugin. "+ + "Plugin version: %s, Ours: %d", parts[1], c.config.ProtocolVersion) + return + } + + switch parts[2] { + case "tcp": + addr, err = net.ResolveTCPAddr("tcp", parts[3]) + case "unix": + addr, err = net.ResolveUnixAddr("unix", parts[3]) + default: + err = fmt.Errorf("Unknown address type: %s", parts[3]) + } + } + + c.address = addr + return +} + +// ReattachConfig returns the information that must be provided to NewClient +// to reattach to the plugin process that this client started. This is +// useful for plugins that detach from their parent process. +// +// If this returns nil then the process hasn't been started yet. Please +// call Start or Client before calling this. +func (c *Client) ReattachConfig() *ReattachConfig { + c.l.Lock() + defer c.l.Unlock() + + if c.address == nil { + return nil + } + + if c.config.Cmd != nil && c.config.Cmd.Process == nil { + return nil + } + + // If we connected via reattach, just return the information as-is + if c.config.Reattach != nil { + return c.config.Reattach + } + + return &ReattachConfig{ + Addr: c.address, + Pid: c.config.Cmd.Process.Pid, + } +} + +func (c *Client) logStderr(r io.Reader) { + bufR := bufio.NewReader(r) + for { + line, err := bufR.ReadString('\n') + if line != "" { + c.config.Stderr.Write([]byte(line)) + + line = strings.TrimRightFunc(line, unicode.IsSpace) + log.Printf("[DEBUG] plugin: %s: %s", filepath.Base(c.config.Cmd.Path), line) + } + + if err == io.EOF { + break + } + } + + // Flag that we've completed logging for others + close(c.doneLogging) +} diff --git a/vendor/github.com/hashicorp/go-plugin/discover.go b/vendor/github.com/hashicorp/go-plugin/discover.go new file mode 100644 index 000000000000..d22c566ed506 --- /dev/null +++ b/vendor/github.com/hashicorp/go-plugin/discover.go @@ -0,0 +1,28 @@ +package plugin + +import ( + "path/filepath" +) + +// Discover discovers plugins that are in a given directory. +// +// The directory doesn't need to be absolute. For example, "." will work fine. +// +// This currently assumes any file matching the glob is a plugin. +// In the future this may be smarter about checking that a file is +// executable and so on. +// +// TODO: test +func Discover(glob, dir string) ([]string, error) { + var err error + + // Make the directory absolute if it isn't already + if !filepath.IsAbs(dir) { + dir, err = filepath.Abs(dir) + if err != nil { + return nil, err + } + } + + return filepath.Glob(filepath.Join(dir, glob)) +} diff --git a/vendor/github.com/hashicorp/go-plugin/error.go b/vendor/github.com/hashicorp/go-plugin/error.go new file mode 100644 index 000000000000..22a7baa6a0d8 --- /dev/null +++ b/vendor/github.com/hashicorp/go-plugin/error.go @@ -0,0 +1,24 @@ +package plugin + +// This is a type that wraps error types so that they can be messaged +// across RPC channels. Since "error" is an interface, we can't always +// gob-encode the underlying structure. This is a valid error interface +// implementer that we will push across. +type BasicError struct { + Message string +} + +// NewBasicError is used to create a BasicError. +// +// err is allowed to be nil. +func NewBasicError(err error) *BasicError { + if err == nil { + return nil + } + + return &BasicError{err.Error()} +} + +func (e *BasicError) Error() string { + return e.Message +} diff --git a/vendor/github.com/hashicorp/go-plugin/mux_broker.go b/vendor/github.com/hashicorp/go-plugin/mux_broker.go new file mode 100644 index 000000000000..01c45ad7c682 --- /dev/null +++ b/vendor/github.com/hashicorp/go-plugin/mux_broker.go @@ -0,0 +1,204 @@ +package plugin + +import ( + "encoding/binary" + "fmt" + "log" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/hashicorp/yamux" +) + +// MuxBroker is responsible for brokering multiplexed connections by unique ID. +// +// It is used by plugins to multiplex multiple RPC connections and data +// streams on top of a single connection between the plugin process and the +// host process. +// +// This allows a plugin to request a channel with a specific ID to connect to +// or accept a connection from, and the broker handles the details of +// holding these channels open while they're being negotiated. +// +// The Plugin interface has access to these for both Server and Client. +// The broker can be used by either (optionally) to reserve and connect to +// new multiplexed streams. This is useful for complex args and return values, +// or anything else you might need a data stream for. +type MuxBroker struct { + nextId uint32 + session *yamux.Session + streams map[uint32]*muxBrokerPending + + sync.Mutex +} + +type muxBrokerPending struct { + ch chan net.Conn + doneCh chan struct{} +} + +func newMuxBroker(s *yamux.Session) *MuxBroker { + return &MuxBroker{ + session: s, + streams: make(map[uint32]*muxBrokerPending), + } +} + +// Accept accepts a connection by ID. +// +// This should not be called multiple times with the same ID at one time. +func (m *MuxBroker) Accept(id uint32) (net.Conn, error) { + var c net.Conn + p := m.getStream(id) + select { + case c = <-p.ch: + close(p.doneCh) + case <-time.After(5 * time.Second): + m.Lock() + defer m.Unlock() + delete(m.streams, id) + + return nil, fmt.Errorf("timeout waiting for accept") + } + + // Ack our connection + if err := binary.Write(c, binary.LittleEndian, id); err != nil { + c.Close() + return nil, err + } + + return c, nil +} + +// AcceptAndServe is used to accept a specific stream ID and immediately +// serve an RPC server on that stream ID. This is used to easily serve +// complex arguments. +// +// The served interface is always registered to the "Plugin" name. +func (m *MuxBroker) AcceptAndServe(id uint32, v interface{}) { + conn, err := m.Accept(id) + if err != nil { + log.Printf("[ERR] plugin: plugin acceptAndServe error: %s", err) + return + } + + serve(conn, "Plugin", v) +} + +// Close closes the connection and all sub-connections. +func (m *MuxBroker) Close() error { + return m.session.Close() +} + +// Dial opens a connection by ID. +func (m *MuxBroker) Dial(id uint32) (net.Conn, error) { + // Open the stream + stream, err := m.session.OpenStream() + if err != nil { + return nil, err + } + + // Write the stream ID onto the wire. + if err := binary.Write(stream, binary.LittleEndian, id); err != nil { + stream.Close() + return nil, err + } + + // Read the ack that we connected. Then we're off! + var ack uint32 + if err := binary.Read(stream, binary.LittleEndian, &ack); err != nil { + stream.Close() + return nil, err + } + if ack != id { + stream.Close() + return nil, fmt.Errorf("bad ack: %d (expected %d)", ack, id) + } + + return stream, nil +} + +// NextId returns a unique ID to use next. +// +// It is possible for very long-running plugin hosts to wrap this value, +// though it would require a very large amount of RPC calls. In practice +// we've never seen it happen. +func (m *MuxBroker) NextId() uint32 { + return atomic.AddUint32(&m.nextId, 1) +} + +// Run starts the brokering and should be executed in a goroutine, since it +// blocks forever, or until the session closes. +// +// Uses of MuxBroker never need to call this. It is called internally by +// the plugin host/client. +func (m *MuxBroker) Run() { + for { + stream, err := m.session.AcceptStream() + if err != nil { + // Once we receive an error, just exit + break + } + + // Read the stream ID from the stream + var id uint32 + if err := binary.Read(stream, binary.LittleEndian, &id); err != nil { + stream.Close() + continue + } + + // Initialize the waiter + p := m.getStream(id) + select { + case p.ch <- stream: + default: + } + + // Wait for a timeout + go m.timeoutWait(id, p) + } +} + +func (m *MuxBroker) getStream(id uint32) *muxBrokerPending { + m.Lock() + defer m.Unlock() + + p, ok := m.streams[id] + if ok { + return p + } + + m.streams[id] = &muxBrokerPending{ + ch: make(chan net.Conn, 1), + doneCh: make(chan struct{}), + } + return m.streams[id] +} + +func (m *MuxBroker) timeoutWait(id uint32, p *muxBrokerPending) { + // Wait for the stream to either be picked up and connected, or + // for a timeout. + timeout := false + select { + case <-p.doneCh: + case <-time.After(5 * time.Second): + timeout = true + } + + m.Lock() + defer m.Unlock() + + // Delete the stream so no one else can grab it + delete(m.streams, id) + + // If we timed out, then check if we have a channel in the buffer, + // and if so, close it. + if timeout { + select { + case s := <-p.ch: + s.Close() + } + } +} diff --git a/vendor/github.com/hashicorp/go-plugin/plugin.go b/vendor/github.com/hashicorp/go-plugin/plugin.go new file mode 100644 index 000000000000..37c8fd653f90 --- /dev/null +++ b/vendor/github.com/hashicorp/go-plugin/plugin.go @@ -0,0 +1,25 @@ +// The plugin package exposes functions and helpers for communicating to +// plugins which are implemented as standalone binary applications. +// +// plugin.Client fully manages the lifecycle of executing the application, +// connecting to it, and returning the RPC client for dispensing plugins. +// +// plugin.Serve fully manages listeners to expose an RPC server from a binary +// that plugin.Client can connect to. +package plugin + +import ( + "net/rpc" +) + +// Plugin is the interface that is implemented to serve/connect to an +// inteface implementation. +type Plugin interface { + // Server should return the RPC server compatible struct to serve + // the methods that the Client calls over net/rpc. + Server(*MuxBroker) (interface{}, error) + + // Client returns an interface implementation for the plugin you're + // serving that communicates to the server end of the plugin. + Client(*MuxBroker, *rpc.Client) (interface{}, error) +} diff --git a/vendor/github.com/hashicorp/go-plugin/process.go b/vendor/github.com/hashicorp/go-plugin/process.go new file mode 100644 index 000000000000..88c999a580d3 --- /dev/null +++ b/vendor/github.com/hashicorp/go-plugin/process.go @@ -0,0 +1,24 @@ +package plugin + +import ( + "time" +) + +// pidAlive checks whether a pid is alive. +func pidAlive(pid int) bool { + return _pidAlive(pid) +} + +// pidWait blocks for a process to exit. +func pidWait(pid int) error { + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + + for range ticker.C { + if !pidAlive(pid) { + break + } + } + + return nil +} diff --git a/vendor/github.com/hashicorp/go-plugin/process_posix.go b/vendor/github.com/hashicorp/go-plugin/process_posix.go new file mode 100644 index 000000000000..70ba546bf6dd --- /dev/null +++ b/vendor/github.com/hashicorp/go-plugin/process_posix.go @@ -0,0 +1,19 @@ +// +build !windows + +package plugin + +import ( + "os" + "syscall" +) + +// _pidAlive tests whether a process is alive or not by sending it Signal 0, +// since Go otherwise has no way to test this. +func _pidAlive(pid int) bool { + proc, err := os.FindProcess(pid) + if err == nil { + err = proc.Signal(syscall.Signal(0)) + } + + return err == nil +} diff --git a/vendor/github.com/hashicorp/go-plugin/process_windows.go b/vendor/github.com/hashicorp/go-plugin/process_windows.go new file mode 100644 index 000000000000..9f7b0180901f --- /dev/null +++ b/vendor/github.com/hashicorp/go-plugin/process_windows.go @@ -0,0 +1,29 @@ +package plugin + +import ( + "syscall" +) + +const ( + // Weird name but matches the MSDN docs + exit_STILL_ACTIVE = 259 + + processDesiredAccess = syscall.STANDARD_RIGHTS_READ | + syscall.PROCESS_QUERY_INFORMATION | + syscall.SYNCHRONIZE +) + +// _pidAlive tests whether a process is alive or not +func _pidAlive(pid int) bool { + h, err := syscall.OpenProcess(processDesiredAccess, false, uint32(pid)) + if err != nil { + return false + } + + var ec uint32 + if e := syscall.GetExitCodeProcess(h, &ec); e != nil { + return false + } + + return ec == exit_STILL_ACTIVE +} diff --git a/vendor/github.com/hashicorp/go-plugin/rpc_client.go b/vendor/github.com/hashicorp/go-plugin/rpc_client.go new file mode 100644 index 000000000000..29f9bf063e7b --- /dev/null +++ b/vendor/github.com/hashicorp/go-plugin/rpc_client.go @@ -0,0 +1,123 @@ +package plugin + +import ( + "fmt" + "io" + "net" + "net/rpc" + + "github.com/hashicorp/yamux" +) + +// RPCClient connects to an RPCServer over net/rpc to dispense plugin types. +type RPCClient struct { + broker *MuxBroker + control *rpc.Client + plugins map[string]Plugin + + // These are the streams used for the various stdout/err overrides + stdout, stderr net.Conn +} + +// NewRPCClient creates a client from an already-open connection-like value. +// Dial is typically used instead. +func NewRPCClient(conn io.ReadWriteCloser, plugins map[string]Plugin) (*RPCClient, error) { + // Create the yamux client so we can multiplex + mux, err := yamux.Client(conn, nil) + if err != nil { + conn.Close() + return nil, err + } + + // Connect to the control stream. + control, err := mux.Open() + if err != nil { + mux.Close() + return nil, err + } + + // Connect stdout, stderr streams + stdstream := make([]net.Conn, 2) + for i, _ := range stdstream { + stdstream[i], err = mux.Open() + if err != nil { + mux.Close() + return nil, err + } + } + + // Create the broker and start it up + broker := newMuxBroker(mux) + go broker.Run() + + // Build the client using our broker and control channel. + return &RPCClient{ + broker: broker, + control: rpc.NewClient(control), + plugins: plugins, + stdout: stdstream[0], + stderr: stdstream[1], + }, nil +} + +// SyncStreams should be called to enable syncing of stdout, +// stderr with the plugin. +// +// This will return immediately and the syncing will continue to happen +// in the background. You do not need to launch this in a goroutine itself. +// +// This should never be called multiple times. +func (c *RPCClient) SyncStreams(stdout io.Writer, stderr io.Writer) error { + go copyStream("stdout", stdout, c.stdout) + go copyStream("stderr", stderr, c.stderr) + return nil +} + +// Close closes the connection. The client is no longer usable after this +// is called. +func (c *RPCClient) Close() error { + // Call the control channel and ask it to gracefully exit. If this + // errors, then we save it so that we always return an error but we + // want to try to close the other channels anyways. + var empty struct{} + returnErr := c.control.Call("Control.Quit", true, &empty) + + // Close the other streams we have + if err := c.control.Close(); err != nil { + return err + } + if err := c.stdout.Close(); err != nil { + return err + } + if err := c.stderr.Close(); err != nil { + return err + } + if err := c.broker.Close(); err != nil { + return err + } + + // Return back the error we got from Control.Quit. This is very important + // since we MUST return non-nil error if this fails so that Client.Kill + // will properly try a process.Kill. + return returnErr +} + +func (c *RPCClient) Dispense(name string) (interface{}, error) { + p, ok := c.plugins[name] + if !ok { + return nil, fmt.Errorf("unknown plugin type: %s", name) + } + + var id uint32 + if err := c.control.Call( + "Dispenser.Dispense", name, &id); err != nil { + return nil, err + } + + conn, err := c.broker.Dial(id) + if err != nil { + return nil, err + } + + return p.Client(c.broker, rpc.NewClient(conn)) +} diff --git a/vendor/github.com/hashicorp/go-plugin/rpc_server.go b/vendor/github.com/hashicorp/go-plugin/rpc_server.go new file mode 100644 index 000000000000..3984dc891ba6 --- /dev/null +++ b/vendor/github.com/hashicorp/go-plugin/rpc_server.go @@ -0,0 +1,185 @@ +package plugin + +import ( + "errors" + "fmt" + "io" + "log" + "net" + "net/rpc" + "sync" + + "github.com/hashicorp/yamux" +) + +// RPCServer listens for network connections and then dispenses interface +// implementations over net/rpc. +// +// After setting the fields below, they shouldn't be read again directly +// from the structure which may be reading/writing them concurrently. +type RPCServer struct { + Plugins map[string]Plugin + + // Stdout, Stderr are what this server will use instead of the + // normal stdin/out/err. This is because due to the multi-process nature + // of our plugin system, we can't use the normal process values so we + // make our own custom one we pipe across. + Stdout io.Reader + Stderr io.Reader + + // DoneCh should be set to a non-nil channel that will be closed + // when the control requests the RPC server to end. + DoneCh chan<- struct{} + + lock sync.Mutex +} + +// Accept accepts connections on a listener and serves requests for +// each incoming connection. Accept blocks; the caller typically invokes +// it in a go statement. +func (s *RPCServer) Accept(lis net.Listener) { + for { + conn, err := lis.Accept() + if err != nil { + log.Printf("[ERR] plugin: plugin server: %s", err) + return + } + + go s.ServeConn(conn) + } +} + +// ServeConn runs a single connection. +// +// ServeConn blocks, serving the connection until the client hangs up. +func (s *RPCServer) ServeConn(conn io.ReadWriteCloser) { + // First create the yamux server to wrap this connection + mux, err := yamux.Server(conn, nil) + if err != nil { + conn.Close() + log.Printf("[ERR] plugin: error creating yamux server: %s", err) + return + } + + // Accept the control connection + control, err := mux.Accept() + if err != nil { + mux.Close() + if err != io.EOF { + log.Printf("[ERR] plugin: error accepting control connection: %s", err) + } + + return + } + + // Connect the stdstreams (in, out, err) + stdstream := make([]net.Conn, 2) + for i, _ := range stdstream { + stdstream[i], err = mux.Accept() + if err != nil { + mux.Close() + log.Printf("[ERR] plugin: accepting stream %d: %s", i, err) + return + } + } + + // Copy std streams out to the proper place + go copyStream("stdout", stdstream[0], s.Stdout) + go copyStream("stderr", stdstream[1], s.Stderr) + + // Create the broker and start it up + broker := newMuxBroker(mux) + go broker.Run() + + // Use the control connection to build the dispenser and serve the + // connection. + server := rpc.NewServer() + server.RegisterName("Control", &controlServer{ + server: s, + }) + server.RegisterName("Dispenser", &dispenseServer{ + broker: broker, + plugins: s.Plugins, + }) + server.ServeConn(control) +} + +// done is called internally by the control server to trigger the +// doneCh to close which is listened to by the main process to cleanly +// exit. +func (s *RPCServer) done() { + s.lock.Lock() + defer s.lock.Unlock() + + if s.DoneCh != nil { + close(s.DoneCh) + s.DoneCh = nil + } +} + +// dispenseServer dispenses variousinterface implementations for Terraform. +type controlServer struct { + server *RPCServer +} + +func (c *controlServer) Quit( + null bool, response *struct{}) error { + // End the server + c.server.done() + + // Always return true + *response = struct{}{} + + return nil +} + +// dispenseServer dispenses variousinterface implementations for Terraform. +type dispenseServer struct { + broker *MuxBroker + plugins map[string]Plugin +} + +func (d *dispenseServer) Dispense( + name string, response *uint32) error { + // Find the function to create this implementation + p, ok := d.plugins[name] + if !ok { + return fmt.Errorf("unknown plugin type: %s", name) + } + + // Create the implementation first so we know if there is an error. + impl, err := p.Server(d.broker) + if err != nil { + // We turn the error into an errors error so that it works across RPC + return errors.New(err.Error()) + } + + // Reserve an ID for our implementation + id := d.broker.NextId() + *response = id + + // Run the rest in a goroutine since it can only happen once this RPC + // call returns. We wait for a connection for the plugin implementation + // and serve it. + go func() { + conn, err := d.broker.Accept(id) + if err != nil { + log.Printf("[ERR] go-plugin: plugin dispense error: %s: %s", name, err) + return + } + + serve(conn, "Plugin", impl) + }() + + return nil +} + +func serve(conn io.ReadWriteCloser, name string, v interface{}) { + server := rpc.NewServer() + if err := server.RegisterName(name, v); err != nil { + log.Printf("[ERR] go-plugin: plugin dispense error: %s", err) + return + } + + server.ServeConn(conn) +} diff --git a/vendor/github.com/hashicorp/go-plugin/server.go b/vendor/github.com/hashicorp/go-plugin/server.go new file mode 100644 index 000000000000..782a4e119d89 --- /dev/null +++ b/vendor/github.com/hashicorp/go-plugin/server.go @@ -0,0 +1,235 @@ +package plugin + +import ( + "crypto/tls" + "errors" + "fmt" + "io/ioutil" + "log" + "net" + "os" + "os/signal" + "runtime" + "strconv" + "sync/atomic" +) + +// CoreProtocolVersion is the ProtocolVersion of the plugin system itself. +// We will increment this whenever we change any protocol behavior. This +// will invalidate any prior plugins but will at least allow us to iterate +// on the core in a safe way. We will do our best to do this very +// infrequently. +const CoreProtocolVersion = 1 + +// HandshakeConfig is the configuration used by client and servers to +// handshake before starting a plugin connection. This is embedded by +// both ServeConfig and ClientConfig. +// +// In practice, the plugin host creates a HandshakeConfig that is exported +// and plugins then can easily consume it. +type HandshakeConfig struct { + // ProtocolVersion is the version that clients must match on to + // agree they can communicate. This should match the ProtocolVersion + // set on ClientConfig when using a plugin. + ProtocolVersion uint + + // MagicCookieKey and value are used as a very basic verification + // that a plugin is intended to be launched. This is not a security + // measure, just a UX feature. If the magic cookie doesn't match, + // we show human-friendly output. + MagicCookieKey string + MagicCookieValue string +} + +// ServeConfig configures what sorts of plugins are served. +type ServeConfig struct { + // HandshakeConfig is the configuration that must match clients. + HandshakeConfig + + // Plugins are the plugins that are served. + Plugins map[string]Plugin + + // TLSProvider is a function that returns a configured tls.Config. + TLSProvider func() (*tls.Config, error) +} + +// Serve serves the plugins given by ServeConfig. +// +// Serve doesn't return until the plugin is done being executed. Any +// errors will be outputted to the log. +// +// This is the method that plugins should call in their main() functions. +func Serve(opts *ServeConfig) { + // Validate the handshake config + if opts.MagicCookieKey == "" || opts.MagicCookieValue == "" { + fmt.Fprintf(os.Stderr, + "Misconfigured ServeConfig given to serve this plugin: no magic cookie\n"+ + "key or value was set. Please notify the plugin author and report\n"+ + "this as a bug.\n") + os.Exit(1) + } + + // First check the cookie + if os.Getenv(opts.MagicCookieKey) != opts.MagicCookieValue { + fmt.Fprintf(os.Stderr, + "This binary is a plugin. These are not meant to be executed directly.\n"+ + "Please execute the program that consumes these plugins, which will\n"+ + "load any plugins automatically\n") + os.Exit(1) + } + + // Logging goes to the original stderr + log.SetOutput(os.Stderr) + + // Create our new stdout, stderr files. These will override our built-in + // stdout/stderr so that it works across the stream boundary. + stdout_r, stdout_w, err := os.Pipe() + if err != nil { + fmt.Fprintf(os.Stderr, "Error preparing plugin: %s\n", err) + os.Exit(1) + } + stderr_r, stderr_w, err := os.Pipe() + if err != nil { + fmt.Fprintf(os.Stderr, "Error preparing plugin: %s\n", err) + os.Exit(1) + } + + // Register a listener so we can accept a connection + listener, err := serverListener() + if err != nil { + log.Printf("[ERR] plugin: plugin init: %s", err) + return + } + + if opts.TLSProvider != nil { + tlsConfig, err := opts.TLSProvider() + if err != nil { + log.Printf("[ERR] plugin: plugin tls init: %s", err) + return + } + listener = tls.NewListener(listener, tlsConfig) + } + defer listener.Close() + + // Create the channel to tell us when we're done + doneCh := make(chan struct{}) + + // Create the RPC server to dispense + server := &RPCServer{ + Plugins: opts.Plugins, + Stdout: stdout_r, + Stderr: stderr_r, + DoneCh: doneCh, + } + + // Output the address and service name to stdout so that core can bring it up. + log.Printf("[DEBUG] plugin: plugin address: %s %s\n", + listener.Addr().Network(), listener.Addr().String()) + fmt.Printf("%d|%d|%s|%s\n", + CoreProtocolVersion, + opts.ProtocolVersion, + listener.Addr().Network(), + listener.Addr().String()) + os.Stdout.Sync() + + // Eat the interrupts + ch := make(chan os.Signal, 1) + signal.Notify(ch, os.Interrupt) + go func() { + var count int32 = 0 + for { + <-ch + newCount := atomic.AddInt32(&count, 1) + log.Printf( + "[DEBUG] plugin: received interrupt signal (count: %d). Ignoring.", + newCount) + } + }() + + // Set our new out, err + os.Stdout = stdout_w + os.Stderr = stderr_w + + // Serve + go server.Accept(listener) + + // Wait for the graceful exit + <-doneCh +} + +func serverListener() (net.Listener, error) { + if runtime.GOOS == "windows" { + return serverListener_tcp() + } + + return serverListener_unix() +} + +func serverListener_tcp() (net.Listener, error) { + minPort, err := strconv.ParseInt(os.Getenv("PLUGIN_MIN_PORT"), 10, 32) + if err != nil { + return nil, err + } + + maxPort, err := strconv.ParseInt(os.Getenv("PLUGIN_MAX_PORT"), 10, 32) + if err != nil { + return nil, err + } + + for port := minPort; port <= maxPort; port++ { + address := fmt.Sprintf("127.0.0.1:%d", port) + listener, err := net.Listen("tcp", address) + if err == nil { + return listener, nil + } + } + + return nil, errors.New("Couldn't bind plugin TCP listener") +} + +func serverListener_unix() (net.Listener, error) { + tf, err := ioutil.TempFile("", "plugin") + if err != nil { + return nil, err + } + path := tf.Name() + + // Close the file and remove it because it has to not exist for + // the domain socket. + if err := tf.Close(); err != nil { + return nil, err + } + if err := os.Remove(path); err != nil { + return nil, err + } + + l, err := net.Listen("unix", path) + if err != nil { + return nil, err + } + + // Wrap the listener in rmListener so that the Unix domain socket file + // is removed on close. + return &rmListener{ + Listener: l, + Path: path, + }, nil +} + +// rmListener is an implementation of net.Listener that forwards most +// calls to the listener but also removes a file as part of the close. We +// use this to cleanup the unix domain socket on close. +type rmListener struct { + net.Listener + Path string +} + +func (l *rmListener) Close() error { + // Close the listener itself + if err := l.Listener.Close(); err != nil { + return err + } + + // Remove the file + return os.Remove(l.Path) +} diff --git a/vendor/github.com/hashicorp/go-plugin/server_mux.go b/vendor/github.com/hashicorp/go-plugin/server_mux.go new file mode 100644 index 000000000000..033079ea0fc5 --- /dev/null +++ b/vendor/github.com/hashicorp/go-plugin/server_mux.go @@ -0,0 +1,31 @@ +package plugin + +import ( + "fmt" + "os" +) + +// ServeMuxMap is the type that is used to configure ServeMux +type ServeMuxMap map[string]*ServeConfig + +// ServeMux is like Serve, but serves multiple types of plugins determined +// by the argument given on the command-line. +// +// This command doesn't return until the plugin is done being executed. Any +// errors are logged or output to stderr. +func ServeMux(m ServeMuxMap) { + if len(os.Args) != 2 { + fmt.Fprintf(os.Stderr, + "Invoked improperly. This is an internal command that shouldn't\n"+ + "be manually invoked.\n") + os.Exit(1) + } + + opts, ok := m[os.Args[1]] + if !ok { + fmt.Fprintf(os.Stderr, "Unknown plugin: %s\n", os.Args[1]) + os.Exit(1) + } + + Serve(opts) +} diff --git a/vendor/github.com/hashicorp/go-plugin/stream.go b/vendor/github.com/hashicorp/go-plugin/stream.go new file mode 100644 index 000000000000..1d547aaaab3f --- /dev/null +++ b/vendor/github.com/hashicorp/go-plugin/stream.go @@ -0,0 +1,18 @@ +package plugin + +import ( + "io" + "log" +) + +func copyStream(name string, dst io.Writer, src io.Reader) { + if src == nil { + panic(name + ": src is nil") + } + if dst == nil { + panic(name + ": dst is nil") + } + if _, err := io.Copy(dst, src); err != nil && err != io.EOF { + log.Printf("[ERR] plugin: stream copy '%s' error: %s", name, err) + } +} diff --git a/vendor/github.com/hashicorp/go-plugin/testing.go b/vendor/github.com/hashicorp/go-plugin/testing.go new file mode 100644 index 000000000000..9086a1b45f60 --- /dev/null +++ b/vendor/github.com/hashicorp/go-plugin/testing.go @@ -0,0 +1,76 @@ +package plugin + +import ( + "bytes" + "net" + "net/rpc" + "testing" +) + +// The testing file contains test helpers that you can use outside of +// this package for making it easier to test plugins themselves. + +// TestConn is a helper function for returning a client and server +// net.Conn connected to each other. +func TestConn(t *testing.T) (net.Conn, net.Conn) { + // Listen to any local port. This listener will be closed + // after a single connection is established. + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %s", err) + } + + // Start a goroutine to accept our client connection + var serverConn net.Conn + doneCh := make(chan struct{}) + go func() { + defer close(doneCh) + defer l.Close() + var err error + serverConn, err = l.Accept() + if err != nil { + t.Fatalf("err: %s", err) + } + }() + + // Connect to the server + clientConn, err := net.Dial("tcp", l.Addr().String()) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Wait for the server side to acknowledge it has connected + <-doneCh + + return clientConn, serverConn +} + +// TestRPCConn returns a rpc client and server connected to each other. +func TestRPCConn(t *testing.T) (*rpc.Client, *rpc.Server) { + clientConn, serverConn := TestConn(t) + + server := rpc.NewServer() + go server.ServeConn(serverConn) + + client := rpc.NewClient(clientConn) + return client, server +} + +// TestPluginRPCConn returns a plugin RPC client and server that are connected +// together and configured. +func TestPluginRPCConn(t *testing.T, ps map[string]Plugin) (*RPCClient, *RPCServer) { + // Create two net.Conns we can use to shuttle our control connection + clientConn, serverConn := TestConn(t) + + // Start up the server + server := &RPCServer{Plugins: ps, Stdout: new(bytes.Buffer), Stderr: new(bytes.Buffer)} + go server.ServeConn(serverConn) + + // Connect the client to the server + client, err := NewRPCClient(clientConn, ps) + if err != nil { + t.Fatalf("err: %s", err) + } + + return client, server +} diff --git a/vendor/vendor.json b/vendor/vendor.json index ee93f5e236b9..ffcc7c5d42fc 100644 --- a/vendor/vendor.json +++ b/vendor/vendor.json @@ -804,6 +804,12 @@ "revision": "ed905158d87462226a13fe39ddf685ea65f1c11f", "revisionTime": "2016-12-16T18:43:04Z" }, + { + "checksumSHA1": "FOLPOFo4xuUaErsL99EC8azEUjw=", + "path": "github.com/hashicorp/go-plugin", + "revision": "b6691c5cfe7f0ec984114b056889cc90e51e38d0", + "revisionTime": "2017-04-12T21:16:38Z" + }, { "checksumSHA1": "ErJHGU6AVPZM9yoY/xV11TwSjQs=", "path": "github.com/hashicorp/go-retryablehttp", From 03e2bcbc7902128a1e5a0181fe2d5beb1ba5fe8f Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 12 Apr 2017 16:41:06 -0700 Subject: [PATCH 073/152] Update Type() to return an error --- builtin/logical/database/backend.go | 2 +- builtin/logical/database/dbplugin/client.go | 10 +++++----- .../logical/database/dbplugin/databasemiddleware.go | 4 ++-- builtin/logical/database/dbplugin/plugin.go | 12 +++++++++--- builtin/logical/database/dbplugin/plugin_test.go | 12 ++++++------ builtin/logical/database/dbplugin/server.go | 5 +++-- 6 files changed, 26 insertions(+), 19 deletions(-) diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index 618ffac6f809..c8f9ad85411b 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -162,5 +162,5 @@ as secret backends, including but not limited to: cassandra, msslq, mysql, postgres After mounting this backend, configure it using the endpoints within -the "database/dbs/" path. +the "database/config/" path. ` diff --git a/builtin/logical/database/dbplugin/client.go b/builtin/logical/database/dbplugin/client.go index 5bdc3a01a0db..93db86595a1c 100644 --- a/builtin/logical/database/dbplugin/client.go +++ b/builtin/logical/database/dbplugin/client.go @@ -52,10 +52,11 @@ func newPluginClient(sys pluginutil.Wrapper, pluginRunner *pluginutil.PluginRunn return nil, err } - // We should have a Greeter now! This feels like a normal interface + // We should have a database type now. This feels like a normal interface // implementation but is in fact over an RPC connection. databaseRPC := raw.(*databasePluginRPCClient) + // Wrap RPC implimentation in DatabasePluginClient return &DatabasePluginClient{ client: client, databasePluginRPCClient: databaseRPC, @@ -70,12 +71,11 @@ type databasePluginRPCClient struct { client *rpc.Client } -func (dr *databasePluginRPCClient) Type() string { +func (dr *databasePluginRPCClient) Type() (string, error) { var dbType string - //TODO: catch error - dr.client.Call("Plugin.Type", struct{}{}, &dbType) + err := dr.client.Call("Plugin.Type", struct{}{}, &dbType) - return fmt.Sprintf("plugin-%s", dbType) + return fmt.Sprintf("plugin-%s", dbType), err } func (dr *databasePluginRPCClient) CreateUser(statements Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) { diff --git a/builtin/logical/database/dbplugin/databasemiddleware.go b/builtin/logical/database/dbplugin/databasemiddleware.go index 2137cd9c388b..e28a8741e43c 100644 --- a/builtin/logical/database/dbplugin/databasemiddleware.go +++ b/builtin/logical/database/dbplugin/databasemiddleware.go @@ -18,7 +18,7 @@ type databaseTracingMiddleware struct { typeStr string } -func (mw *databaseTracingMiddleware) Type() string { +func (mw *databaseTracingMiddleware) Type() (string, error) { return mw.next.Type() } @@ -87,7 +87,7 @@ type databaseMetricsMiddleware struct { typeStr string } -func (mw *databaseMetricsMiddleware) Type() string { +func (mw *databaseMetricsMiddleware) Type() (string, error) { return mw.next.Type() } diff --git a/builtin/logical/database/dbplugin/plugin.go b/builtin/logical/database/dbplugin/plugin.go index dadb6639eea4..5e6ce939bea1 100644 --- a/builtin/logical/database/dbplugin/plugin.go +++ b/builtin/logical/database/dbplugin/plugin.go @@ -2,6 +2,7 @@ package dbplugin import ( "errors" + "fmt" "net/rpc" "time" @@ -16,7 +17,7 @@ var ( // DatabaseType is the interface that all database objects must implement. type DatabaseType interface { - Type() string + Type() (string, error) CreateUser(statements Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) RenewUser(statements Statements, username string, expiration time.Time) error RevokeUser(statements Statements, username string) error @@ -52,16 +53,21 @@ func PluginFactory(pluginName string, sys pluginutil.LookWrapper, logger log.Log return nil, err } + typeStr, err := db.Type() + if err != nil { + return nil, fmt.Errorf("error getting plugin type: %s", err) + } + // Wrap with metrics middleware db = &databaseMetricsMiddleware{ next: db, - typeStr: db.Type(), + typeStr: typeStr, } // Wrap with tracing middleware db = &databaseTracingMiddleware{ next: db, - typeStr: db.Type(), + typeStr: typeStr, logger: logger, } diff --git a/builtin/logical/database/dbplugin/plugin_test.go b/builtin/logical/database/dbplugin/plugin_test.go index 7909bbd4e53d..1587ba24a5b4 100644 --- a/builtin/logical/database/dbplugin/plugin_test.go +++ b/builtin/logical/database/dbplugin/plugin_test.go @@ -19,7 +19,7 @@ type mockPlugin struct { users map[string][]string } -func (m *mockPlugin) Type() string { return "mock" } +func (m *mockPlugin) Type() (string, error) { return "mock", nil } func (m *mockPlugin) CreateUser(statements dbplugin.Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) { err = errors.New("err") if usernamePrefix == "" || expiration.IsZero() { @@ -59,7 +59,7 @@ func (m *mockPlugin) RevokeUser(statements dbplugin.Statements, username string) delete(m.users, username) return nil } -func (m *mockPlugin) Initialize(conf map[string]interface{}) error { +func (m *mockPlugin) Initialize(conf map[string]interface{}, _ bool) error { err := errors.New("err") if len(conf) != 1 { return err @@ -108,7 +108,7 @@ func TestPlugin_Initialize(t *testing.T) { "test": 1, } - err = dbRaw.Initialize(connectionDetails) + err = dbRaw.Initialize(connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -133,7 +133,7 @@ func TestPlugin_CreateUser(t *testing.T) { "test": 1, } - err = db.Initialize(connectionDetails) + err = db.Initialize(connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -167,7 +167,7 @@ func TestPlugin_RenewUser(t *testing.T) { connectionDetails := map[string]interface{}{ "test": 1, } - err = db.Initialize(connectionDetails) + err = db.Initialize(connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } @@ -196,7 +196,7 @@ func TestPlugin_RevokeUser(t *testing.T) { connectionDetails := map[string]interface{}{ "test": 1, } - err = db.Initialize(connectionDetails) + err = db.Initialize(connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) } diff --git a/builtin/logical/database/dbplugin/server.go b/builtin/logical/database/dbplugin/server.go index 326e25103cff..3a3e233946ec 100644 --- a/builtin/logical/database/dbplugin/server.go +++ b/builtin/logical/database/dbplugin/server.go @@ -42,8 +42,9 @@ type databasePluginRPCServer struct { } func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error { - *resp = ds.impl.Type() - return nil + var err error + *resp, err = ds.impl.Type() + return err } func (ds *databasePluginRPCServer) CreateUser(args *CreateUserRequest, resp *CreateUserResponse) error { From 4c75326aad6b9a9a9ca1fdd4626007ab6dd4b10c Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 12 Apr 2017 17:35:02 -0700 Subject: [PATCH 074/152] Cleanup path files --- builtin/logical/database/backend.go | 10 +-- builtin/logical/database/dbplugin/plugin.go | 9 --- .../database/path_config_connection.go | 71 +++++++++++-------- builtin/logical/database/path_roles.go | 1 + command/{plugin-exec.go => plugin_exec.go} | 0 5 files changed, 46 insertions(+), 45 deletions(-) rename command/{plugin-exec.go => plugin_exec.go} (100%) diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index c8f9ad85411b..2ce7595260bd 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -12,7 +12,7 @@ import ( "github.com/hashicorp/vault/logical/framework" ) -const databaseConfigPath = "database/dbs/" +const databaseConfigPath = "database/config/" // DatabaseConfig is used by the Factory function to configure a DatabaseType // object. @@ -32,12 +32,6 @@ func Backend(conf *logical.BackendConfig) *databaseBackend { b.Backend = &framework.Backend{ Help: strings.TrimSpace(backendHelp), - PathsSpecial: &logical.Paths{ - Root: []string{ - "dbs/plugin/*", - }, - }, - Paths: []*framework.Path{ pathConfigurePluginConnection(&b), pathListRoles(&b), @@ -90,7 +84,7 @@ func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (dbpl return db, nil } - entry, err := s.Get(fmt.Sprintf("dbs/%s", name)) + entry, err := s.Get(fmt.Sprintf("config/%s", name)) if err != nil { return nil, fmt.Errorf("failed to read connection configuration with name: %s", name) } diff --git a/builtin/logical/database/dbplugin/plugin.go b/builtin/logical/database/dbplugin/plugin.go index 5e6ce939bea1..61de0fe8ce05 100644 --- a/builtin/logical/database/dbplugin/plugin.go +++ b/builtin/logical/database/dbplugin/plugin.go @@ -1,7 +1,6 @@ package dbplugin import ( - "errors" "fmt" "net/rpc" "time" @@ -11,10 +10,6 @@ import ( log "github.com/mgutz/logxi/v1" ) -var ( - ErrEmptyPluginName = errors.New("empty plugin name") -) - // DatabaseType is the interface that all database objects must implement. type DatabaseType interface { Type() (string, error) @@ -37,10 +32,6 @@ type Statements struct { // PluginFactory is used to build plugin database types. It wraps the database // object in a logging and metrics middleware. func PluginFactory(pluginName string, sys pluginutil.LookWrapper, logger log.Logger) (DatabaseType, error) { - if pluginName == "" { - return nil, ErrEmptyPluginName - } - // Look for plugin in the plugin catalog pluginMeta, err := sys.LookupPlugin(pluginName) if err != nil { diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 5817f53c2d8c..f69c7761b2e3 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -1,6 +1,7 @@ package database import ( + "errors" "fmt" "github.com/fatih/structs" @@ -9,6 +10,11 @@ import ( "github.com/hashicorp/vault/logical/framework" ) +var ( + respErrEmptyPluginName = logical.ErrorResponse("empty plugin name") + respErrEmptyName = logical.ErrorResponse("Empty name attribute given") +) + // pathResetConnection configures a path to reset a plugin. func pathResetConnection(b *databaseBackend) *framework.Path { return &framework.Path{ @@ -16,7 +22,7 @@ func pathResetConnection(b *databaseBackend) *framework.Path { Fields: map[string]*framework.FieldSchema{ "name": &framework.FieldSchema{ Type: framework.TypeString, - Description: "Name of this DB type", + Description: "Name of this database connection", }, }, @@ -35,15 +41,17 @@ func (b *databaseBackend) pathConnectionReset() framework.OperationFunc { return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { name := data.Get("name").(string) if name == "" { - return logical.ErrorResponse("Empty name attribute given"), nil + return respErrEmptyName, nil } // Grab the mutex lock b.Lock() defer b.Unlock() + // Close plugin and delete the entry in the connections cache. b.clearConnection(name) + // Execute plugin again, we don't need the object so throw away. _, err := b.getOrCreateDBObj(req.Storage, name) if err != nil { return nil, err @@ -61,14 +69,7 @@ func pathConfigurePluginConnection(b *databaseBackend) *framework.Path { Fields: map[string]*framework.FieldSchema{ "name": &framework.FieldSchema{ Type: framework.TypeString, - Description: "Name of this DB type", - }, - - "verify_connection": &framework.FieldSchema{ - Type: framework.TypeBool, - Default: true, - Description: `If set, the connection details are verified by - actually connecting to the database`, + Description: "Name of this database connection", }, "plugin_name": &framework.FieldSchema{ @@ -77,6 +78,13 @@ func pathConfigurePluginConnection(b *databaseBackend) *framework.Path { plugin known to vault. This endpoint will create an instance of that plugin type.`, }, + + "verify_connection": &framework.FieldSchema{ + Type: framework.TypeBool, + Default: true, + Description: `If true, the connection details are verified by + actually connecting to the database. Defaults to true.`, + }, }, Callbacks: map[logical.Operation]framework.OperationFunc{ @@ -94,10 +102,13 @@ func pathConfigurePluginConnection(b *databaseBackend) *framework.Path { func (b *databaseBackend) connectionReadHandler() framework.OperationFunc { return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { name := data.Get("name").(string) + if name == "" { + return respErrEmptyName, nil + } - entry, err := req.Storage.Get(fmt.Sprintf("dbs/%s", name)) + entry, err := req.Storage.Get(fmt.Sprintf("config/%s", name)) if err != nil { - return nil, fmt.Errorf("failed to read connection configuration") + return nil, errors.New("failed to read connection configuration") } if entry == nil { return nil, nil @@ -118,12 +129,12 @@ func (b *databaseBackend) connectionDeleteHandler() framework.OperationFunc { return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { name := data.Get("name").(string) if name == "" { - return logical.ErrorResponse("Empty name attribute given"), nil + return respErrEmptyName, nil } - err := req.Storage.Delete(fmt.Sprintf("dbs/%s", name)) + err := req.Storage.Delete(fmt.Sprintf("config/%s", name)) if err != nil { - return nil, fmt.Errorf("failed to delete connection configuration") + return nil, errors.New("failed to delete connection configuration") } b.Lock() @@ -134,9 +145,9 @@ func (b *databaseBackend) connectionDeleteHandler() framework.OperationFunc { if err != nil { return nil, err } - } - delete(b.connections, name) + delete(b.connections, name) + } return nil, nil } @@ -146,22 +157,22 @@ func (b *databaseBackend) connectionDeleteHandler() framework.OperationFunc { // both builtin and plugin database types. func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - - config := &DatabaseConfig{ - ConnectionDetails: data.Raw, - PluginName: data.Get("plugin_name").(string), + pluginName := data.Get("plugin_name").(string) + if pluginName == "" { + return respErrEmptyPluginName, nil } name := data.Get("name").(string) if name == "" { - return logical.ErrorResponse("Empty name attribute given"), nil + return respErrEmptyName, nil } verifyConnection := data.Get("verify_connection").(bool) - // Grab the mutex lock - b.Lock() - defer b.Unlock() + config := &DatabaseConfig{ + ConnectionDetails: data.Raw, + PluginName: pluginName, + } db, err := dbplugin.PluginFactory(config.PluginName, b.System(), b.logger) if err != nil { @@ -174,6 +185,10 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil } + // Grab the mutex lock + b.Lock() + defer b.Unlock() + if _, ok := b.connections[name]; ok { // Close and remove the old connection err := b.connections[name].Close() @@ -189,7 +204,7 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { b.connections[name] = db // Store it - entry, err := logical.StorageEntryJSON(fmt.Sprintf("dbs/%s", name), config) + entry, err := logical.StorageEntryJSON(fmt.Sprintf("config/%s", name), config) if err != nil { return nil, err } @@ -198,7 +213,7 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { } resp := &logical.Response{} - resp.AddWarning("Read access to this endpoint should be controlled via ACLs as it will return the connection string or URL as it is, including passwords, if any.") + resp.AddWarning("Read access to this endpoint should be controlled via ACLs as it will return the connection details as is, including passwords, if any.") return resp, nil } @@ -221,7 +236,7 @@ accepts: plugin known to vault. This endpoint will create an instance of that plugin type. - * "verify_connection" - A boolean value denoting if the plugin should verify + * "verify_connection" (default: true) - A boolean value denoting if the plugin should verify it is able to connect to the database using the provided connection details. ` diff --git a/builtin/logical/database/path_roles.go b/builtin/logical/database/path_roles.go index 263a555e6b3b..b3393b1ba884 100644 --- a/builtin/logical/database/path_roles.go +++ b/builtin/logical/database/path_roles.go @@ -109,6 +109,7 @@ func (b *databaseBackend) pathRoleRead() framework.OperationFunc { return &logical.Response{ Data: map[string]interface{}{ + "db_name": role.DBName, "creation_statements": role.Statements.CreationStatements, "revocation_statements": role.Statements.RevocationStatements, "rollback_statements": role.Statements.RollbackStatements, diff --git a/command/plugin-exec.go b/command/plugin_exec.go similarity index 100% rename from command/plugin-exec.go rename to command/plugin_exec.go From 33d66f3a67c41db9eb5770e41983dbf18f8918f8 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 12 Apr 2017 17:35:53 -0700 Subject: [PATCH 075/152] Add comments to the plugin runner --- helper/pluginutil/runner.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/helper/pluginutil/runner.go b/helper/pluginutil/runner.go index 4d66d8706bce..a57abad0edc5 100644 --- a/helper/pluginutil/runner.go +++ b/helper/pluginutil/runner.go @@ -17,20 +17,28 @@ var ( PluginMlockEnabled = "VAULT_PLUGIN_MLOCK_ENABLED" ) +// Looker defines the plugin Lookup function that looks into the plugin catalog +// for availible plugins and returns a PluginRunner type Looker interface { LookupPlugin(string) (*PluginRunner, error) } +// Wrapper interface defines the functions needed by the runner to wrap the +// metadata needed to run a plugin process. This includes looking up Mlock +// configuration and wrapping data in a respose wrapped token. type Wrapper interface { ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (string, error) MlockDisabled() bool } +// LookWrapper defines the functions for both Looker and Wrapper type LookWrapper interface { Looker Wrapper } +// PluginRunner defines the metadata needed to run a plugin securely with +// go-plugin. type PluginRunner struct { Name string `json:"name"` Command string `json:"command"` @@ -39,6 +47,8 @@ type PluginRunner struct { Builtin bool `json:"builtin"` } +// Run takes a wrapper instance, and the go-plugin paramaters and executes a +// plugin. func (r *PluginRunner) Run(wrapper Wrapper, pluginMap map[string]plugin.Plugin, hs plugin.HandshakeConfig, env []string) (*plugin.Client, error) { // Get a CA TLS Certificate CACertBytes, CACert, CAKey, err := GenerateCACert() @@ -87,6 +97,8 @@ func (r *PluginRunner) Run(wrapper Wrapper, pluginMap map[string]plugin.Plugin, return client, nil } +// OptionallyEnableMlock determines if mlock should be called, and if so enables +// mlock. func OptionallyEnableMlock() error { if os.Getenv(PluginMlockEnabled) == "true" { return mlock.LockMemory() From b20c17745c549240baffd92954c3df672120e701 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 13 Apr 2017 10:33:34 -0700 Subject: [PATCH 076/152] Add allowed_roles parameter and checks --- builtin/logical/database/backend.go | 36 +++---- builtin/logical/database/backend_test.go | 101 ++++++++++++++++++ .../database/path_config_connection.go | 31 +++++- builtin/logical/database/path_role_create.go | 12 +++ 4 files changed, 158 insertions(+), 22 deletions(-) diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index 2ce7595260bd..e57fa19c180d 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -14,15 +14,6 @@ import ( const databaseConfigPath = "database/config/" -// DatabaseConfig is used by the Factory function to configure a DatabaseType -// object. -type DatabaseConfig struct { - PluginName string `json:"plugin_name" structs:"plugin_name" mapstructure:"plugin_name"` - // ConnectionDetails stores the database specific connection settings needed - // by each database type. - ConnectionDetails map[string]interface{} `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"` -} - func Factory(conf *logical.BackendConfig) (logical.Backend, error) { return Backend(conf).Setup(conf) } @@ -84,16 +75,8 @@ func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (dbpl return db, nil } - entry, err := s.Get(fmt.Sprintf("config/%s", name)) + config, err := b.DatabaseConfig(s, name) if err != nil { - return nil, fmt.Errorf("failed to read connection configuration with name: %s", name) - } - if entry == nil { - return nil, fmt.Errorf("failed to find entry for connection with name: %s", name) - } - - var config DatabaseConfig - if err := entry.DecodeJSON(&config); err != nil { return nil, err } @@ -112,6 +95,23 @@ func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (dbpl return db, nil } +func (b *databaseBackend) DatabaseConfig(s logical.Storage, name string) (*DatabaseConfig, error) { + entry, err := s.Get(fmt.Sprintf("config/%s", name)) + if err != nil { + return nil, fmt.Errorf("failed to read connection configuration with name: %s", name) + } + if entry == nil { + return nil, fmt.Errorf("failed to find entry for connection with name: %s", name) + } + + var config DatabaseConfig + if err := entry.DecodeJSON(&config); err != nil { + return nil, err + } + + return &config, nil +} + func (b *databaseBackend) Role(s logical.Storage, n string) (*roleEntry, error) { entry, err := s.Get("role/" + n) if err != nil { diff --git a/builtin/logical/database/backend_test.go b/builtin/logical/database/backend_test.go index 5b3a0db42abe..2615577fdb51 100644 --- a/builtin/logical/database/backend_test.go +++ b/builtin/logical/database/backend_test.go @@ -130,6 +130,7 @@ func TestBackend_config_connection(t *testing.T) { expected := map[string]interface{}{ "plugin_name": "postgresql-database-plugin", "connection_details": configData, + "allowed_roles": []string{}, } configReq.Operation = logical.ReadOperation resp, err = b.HandleRequest(configReq) @@ -306,6 +307,7 @@ func TestBackend_connectionCrud(t *testing.T) { expected := map[string]interface{}{ "plugin_name": "postgresql-database-plugin", "connection_details": data, + "allowed_roles": []string{}, } req.Operation = logical.ReadOperation resp, err = b.HandleRequest(req) @@ -484,6 +486,105 @@ func TestBackend_roleCrud(t *testing.T) { t.Fatal("Expected response to be nil") } } +func TestBackend_allowedRoles(t *testing.T) { + _, ln, sys, _ := getCore(t) + defer ln.Close() + + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + config.System = sys + + b, err := Factory(config) + if err != nil { + t.Fatal(err) + } + defer b.Cleanup() + + cleanup, connURL := preparePostgresTestContainer(t, config.StorageView, b) + defer cleanup() + + // Configure a connection + data := map[string]interface{}{ + "connection_url": connURL, + "plugin_name": "postgresql-database-plugin", + "allowed_roles": "allow, allowed", + } + req := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "config/plugin-test", + Storage: config.StorageView, + Data: data, + } + resp, err := b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Create a denied and an allowed role + data = map[string]interface{}{ + "db_name": "plugin-test", + "creation_statements": testRole, + "default_ttl": "5m", + "max_ttl": "10m", + } + req = &logical.Request{ + Operation: logical.UpdateOperation, + Path: "roles/denied", + Storage: config.StorageView, + Data: data, + } + resp, err = b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + data = map[string]interface{}{ + "db_name": "plugin-test", + "creation_statements": testRole, + "default_ttl": "5m", + "max_ttl": "10m", + } + req = &logical.Request{ + Operation: logical.UpdateOperation, + Path: "roles/allowed", + Storage: config.StorageView, + Data: data, + } + resp, err = b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Get creds from denied role, should fail + data = map[string]interface{}{} + req = &logical.Request{ + Operation: logical.ReadOperation, + Path: "creds/denied", + Storage: config.StorageView, + Data: data, + } + credsResp, err := b.HandleRequest(req) + if err != logical.ErrPermissionDenied { + t.Fatalf("expected error to be:%s got:%#v\n", logical.ErrPermissionDenied, err) + } + + // Get creds from allowed role, should work. + data = map[string]interface{}{} + req = &logical.Request{ + Operation: logical.ReadOperation, + Path: "creds/allowed", + Storage: config.StorageView, + Data: data, + } + credsResp, err = b.HandleRequest(req) + if err != nil || (credsResp != nil && credsResp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, credsResp) + } + + if !testCredsExist(t, credsResp, connURL) { + t.Fatalf("Creds should exist") + } +} func testCredsExist(t *testing.T, resp *logical.Response, connURL string) bool { var d struct { diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index f69c7761b2e3..2a0022b4d869 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -6,15 +6,26 @@ import ( "github.com/fatih/structs" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" + "github.com/hashicorp/vault/helper/strutil" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) var ( - respErrEmptyPluginName = logical.ErrorResponse("empty plugin name") + respErrEmptyPluginName = logical.ErrorResponse("Empty plugin name") respErrEmptyName = logical.ErrorResponse("Empty name attribute given") ) +// DatabaseConfig is used by the Factory function to configure a DatabaseType +// object. +type DatabaseConfig struct { + PluginName string `json:"plugin_name" structs:"plugin_name" mapstructure:"plugin_name"` + // ConnectionDetails stores the database specific connection settings needed + // by each database type. + ConnectionDetails map[string]interface{} `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"` + AllowedRoles []string `json:"allowed_roles" structs:"allowed_roles" mapstructure:"allowed_roles"` +} + // pathResetConnection configures a path to reset a plugin. func pathResetConnection(b *databaseBackend) *framework.Path { return &framework.Path{ @@ -75,15 +86,22 @@ func pathConfigurePluginConnection(b *databaseBackend) *framework.Path { "plugin_name": &framework.FieldSchema{ Type: framework.TypeString, Description: `The name of a builtin or previously registered - plugin known to vault. This endpoint will create an instance of - that plugin type.`, + plugin known to vault. This endpoint will create an instance of + that plugin type.`, }, "verify_connection": &framework.FieldSchema{ Type: framework.TypeBool, Default: true, Description: `If true, the connection details are verified by - actually connecting to the database. Defaults to true.`, + actually connecting to the database. Defaults to true.`, + }, + + "allowed_roles": &framework.FieldSchema{ + Type: framework.TypeString, + Description: `Comma separated list of the role names allowed to + get creds from this database connection. If not set all roles + are allowed.`, }, }, @@ -169,9 +187,14 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { verifyConnection := data.Get("verify_connection").(bool) + // Pasrse and dedupe allowed roles from a comma separated string. + allowedRolesRaw := data.Get("allowed_roles").(string) + allowedRoles := strutil.ParseDedupAndSortStrings(allowedRolesRaw, ",") + config := &DatabaseConfig{ ConnectionDetails: data.Raw, PluginName: pluginName, + AllowedRoles: allowedRoles, } db, err := dbplugin.PluginFactory(config.PluginName, b.System(), b.logger) diff --git a/builtin/logical/database/path_role_create.go b/builtin/logical/database/path_role_create.go index 59584e9437cf..631802dff643 100644 --- a/builtin/logical/database/path_role_create.go +++ b/builtin/logical/database/path_role_create.go @@ -4,6 +4,7 @@ import ( "fmt" "time" + "github.com/hashicorp/vault/helper/strutil" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -40,6 +41,17 @@ func (b *databaseBackend) pathRoleCreateRead() framework.OperationFunc { return logical.ErrorResponse(fmt.Sprintf("Unknown role: %s", name)), nil } + dbConfig, err := b.DatabaseConfig(req.Storage, role.DBName) + if err != nil { + return nil, err + } + + // If role name isn't in the database's allowed roles, send back a + // permission denied. + if len(dbConfig.AllowedRoles) > 0 && !strutil.StrListContains(dbConfig.AllowedRoles, name) { + return nil, logical.ErrPermissionDenied + } + b.Lock() defer b.Unlock() From 07f3f4fc264ac99e1cfe80030a6968f8ccd36d7d Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 13 Apr 2017 11:22:53 -0700 Subject: [PATCH 077/152] Update the plugin directory logic --- command/server.go | 7 +++++++ command/server/config.go | 5 +++++ 2 files changed, 12 insertions(+) diff --git a/command/server.go b/command/server.go index 17f40d1655c9..c402d09d5244 100644 --- a/command/server.go +++ b/command/server.go @@ -284,6 +284,13 @@ func (c *ServerCommand) Run(args []string) int { return 1 } coreConfig.PluginDirectory = filepath.Join(homePath, "/.vault-plugins/") + err = os.Mkdir(coreConfig.PluginDirectory, 0700) + if err != nil && !os.IsExist(err) { + c.Ui.Output(fmt.Sprintf( + "Error making default plugin directory: %v", err)) + return 1 + } + } var disableClustering bool diff --git a/command/server/config.go b/command/server/config.go index 4821a29ba8bf..dad485928d8e 100644 --- a/command/server/config.go +++ b/command/server/config.go @@ -273,6 +273,11 @@ func (c *Config) Merge(c2 *Config) *Config { result.EnableUI = c2.EnableUI } + result.PluginDirectory = c.PluginDirectory + if c2.PluginDirectory != "" { + result.PluginDirectory = c2.PluginDirectory + } + return result } From be50cbae91a3af52f72d2ead17222bb300cfe8ea Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 13 Apr 2017 13:48:32 -0700 Subject: [PATCH 078/152] Move plugins into main vault repo --- helper/builtinplugins/builtin.go | 4 +- .../mssql/mssql-database-plugin/main.go | 16 + plugins/database/mssql/mssql.go | 268 ++++++++++++++ plugins/database/mssql/mssql_test.go | 173 +++++++++ .../mysql/mysql-database-plugin/main.go | 16 + plugins/database/mysql/mysql.go | 183 ++++++++++ plugins/database/mysql/mysql_test.go | 200 +++++++++++ .../postgresql-database-plugin/main.go | 16 + plugins/database/postgresql/postgresql.go | 337 ++++++++++++++++++ .../database/postgresql/postgresql_test.go | 308 ++++++++++++++++ plugins/helper/database/connutil/cassandra.go | 172 +++++++++ plugins/helper/database/connutil/connutil.go | 21 ++ plugins/helper/database/connutil/sql.go | 131 +++++++ .../helper/database/credsutil/cassandra.go | 37 ++ .../helper/database/credsutil/credsutil.go | 12 + plugins/helper/database/credsutil/sql.go | 43 +++ plugins/helper/database/dbutil/dbutil.go | 20 ++ 17 files changed, 1955 insertions(+), 2 deletions(-) create mode 100644 plugins/database/mssql/mssql-database-plugin/main.go create mode 100644 plugins/database/mssql/mssql.go create mode 100644 plugins/database/mssql/mssql_test.go create mode 100644 plugins/database/mysql/mysql-database-plugin/main.go create mode 100644 plugins/database/mysql/mysql.go create mode 100644 plugins/database/mysql/mysql_test.go create mode 100644 plugins/database/postgresql/postgresql-database-plugin/main.go create mode 100644 plugins/database/postgresql/postgresql.go create mode 100644 plugins/database/postgresql/postgresql_test.go create mode 100644 plugins/helper/database/connutil/cassandra.go create mode 100644 plugins/helper/database/connutil/connutil.go create mode 100644 plugins/helper/database/connutil/sql.go create mode 100644 plugins/helper/database/credsutil/cassandra.go create mode 100644 plugins/helper/database/credsutil/credsutil.go create mode 100644 plugins/helper/database/credsutil/sql.go create mode 100644 plugins/helper/database/dbutil/dbutil.go diff --git a/helper/builtinplugins/builtin.go b/helper/builtinplugins/builtin.go index 55da9a97f310..beedbb15b857 100644 --- a/helper/builtinplugins/builtin.go +++ b/helper/builtinplugins/builtin.go @@ -1,8 +1,8 @@ package builtinplugins import ( - "github.com/hashicorp/vault-plugins/database/mysql" - "github.com/hashicorp/vault-plugins/database/postgresql" + "github.com/hashicorp/vault/plugins/database/mysql" + "github.com/hashicorp/vault/plugins/database/postgresql" ) var BuiltinPlugins *builtinPlugins = &builtinPlugins{ diff --git a/plugins/database/mssql/mssql-database-plugin/main.go b/plugins/database/mssql/mssql-database-plugin/main.go new file mode 100644 index 000000000000..ead1cf842306 --- /dev/null +++ b/plugins/database/mssql/mssql-database-plugin/main.go @@ -0,0 +1,16 @@ +package main + +import ( + "fmt" + "os" + + "github.com/hashicorp/vault/plugins/database/mssql" +) + +func main() { + err := mssql.Run() + if err != nil { + fmt.Println(err) + os.Exit(1) + } +} diff --git a/plugins/database/mssql/mssql.go b/plugins/database/mssql/mssql.go new file mode 100644 index 000000000000..567a095b664a --- /dev/null +++ b/plugins/database/mssql/mssql.go @@ -0,0 +1,268 @@ +package mssql + +import ( + "database/sql" + "fmt" + "strings" + "time" + + "github.com/hashicorp/vault/builtin/logical/database/dbplugin" + "github.com/hashicorp/vault/helper/strutil" + "github.com/hashicorp/vault/plugins/helper/database/connutil" + "github.com/hashicorp/vault/plugins/helper/database/credsutil" + "github.com/hashicorp/vault/plugins/helper/database/dbutil" +) + +const msSQLTypeName = "mssql" + +// MSSQL is an implementation of DatabaseType interface +type MSSQL struct { + connutil.ConnectionProducer + credsutil.CredentialsProducer +} + +func New() *MSSQL { + connProducer := &connutil.SQLConnectionProducer{} + connProducer.Type = msSQLTypeName + + credsProducer := &credsutil.SQLCredentialsProducer{ + DisplayNameLen: 4, + UsernameLen: 16, + } + + dbType := &MSSQL{ + ConnectionProducer: connProducer, + CredentialsProducer: credsProducer, + } + + return dbType +} + +// Run instantiates a MSSQL object, and runs the RPC server for the plugin +func Run() error { + dbType := New() + + dbplugin.NewPluginServer(dbType) + + return nil +} + +// Type returns the TypeName for this backend +func (m *MSSQL) Type() (string, error) { + return msSQLTypeName, nil +} + +func (m *MSSQL) getConnection() (*sql.DB, error) { + db, err := m.Connection() + if err != nil { + return nil, err + } + + return db.(*sql.DB), nil +} + +// CreateUser generates the username/password on the underlying MSSQL secret backend as instructed by +// the CreationStatement provided. +func (m *MSSQL) CreateUser(statements dbplugin.Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) { + // Grab the lock + m.Lock() + defer m.Unlock() + + // Get the connection + db, err := m.getConnection() + if err != nil { + return "", "", err + } + + if statements.CreationStatements == "" { + return "", "", dbutil.ErrEmptyCreationStatement + } + + username, err = m.GenerateUsername(usernamePrefix) + if err != nil { + return "", "", err + } + + password, err = m.GeneratePassword() + if err != nil { + return "", "", err + } + + expirationStr, err := m.GenerateExpiration(expiration) + if err != nil { + return "", "", err + } + + // Start a transaction + tx, err := db.Begin() + if err != nil { + return "", "", err + } + defer tx.Rollback() + + // Execute each query + for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{ + "name": username, + "password": password, + "expiration": expirationStr, + })) + if err != nil { + return "", "", err + } + defer stmt.Close() + if _, err := stmt.Exec(); err != nil { + return "", "", err + } + } + + // Commit the transaction + if err := tx.Commit(); err != nil { + return "", "", err + } + + return username, password, nil +} + +// RenewUser is not supported on MSSQL, so this is a no-op. +func (m *MSSQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { + // NOOP + return nil +} + +// RevokeUser attempts to drop the specified user. It will first attempt to disable login, +// then kill pending connections from that user, and finally drop the user and login from the +// database instance. +func (m *MSSQL) RevokeUser(statements dbplugin.Statements, username string) error { + // Get connection + db, err := m.getConnection() + if err != nil { + return err + } + + // First disable server login + disableStmt, err := db.Prepare(fmt.Sprintf("ALTER LOGIN [%s] DISABLE;", username)) + if err != nil { + return err + } + defer disableStmt.Close() + if _, err := disableStmt.Exec(); err != nil { + return err + } + + // Query for sessions for the login so that we can kill any outstanding + // sessions. There cannot be any active sessions before we drop the logins + // This isn't done in a transaction because even if we fail along the way, + // we want to remove as much access as possible + sessionStmt, err := db.Prepare(fmt.Sprintf( + "SELECT session_id FROM sys.dm_exec_sessions WHERE login_name = '%s';", username)) + if err != nil { + return err + } + defer sessionStmt.Close() + + sessionRows, err := sessionStmt.Query() + if err != nil { + return err + } + defer sessionRows.Close() + + var revokeStmts []string + for sessionRows.Next() { + var sessionID int + err = sessionRows.Scan(&sessionID) + if err != nil { + return err + } + revokeStmts = append(revokeStmts, fmt.Sprintf("KILL %d;", sessionID)) + } + + // Query for database users using undocumented stored procedure for now since + // it is the easiest way to get this information; + // we need to drop the database users before we can drop the login and the role + // This isn't done in a transaction because even if we fail along the way, + // we want to remove as much access as possible + stmt, err := db.Prepare(fmt.Sprintf("EXEC sp_msloginmappings '%s';", username)) + if err != nil { + return err + } + defer stmt.Close() + + rows, err := stmt.Query() + if err != nil { + return err + } + defer rows.Close() + + for rows.Next() { + var loginName, dbName, qUsername string + var aliasName sql.NullString + err = rows.Scan(&loginName, &dbName, &qUsername, &aliasName) + if err != nil { + return err + } + revokeStmts = append(revokeStmts, fmt.Sprintf(dropUserSQL, dbName, username, username)) + } + + // we do not stop on error, as we want to remove as + // many permissions as possible right now + var lastStmtError error + for _, query := range revokeStmts { + stmt, err := db.Prepare(query) + if err != nil { + lastStmtError = err + continue + } + defer stmt.Close() + _, err = stmt.Exec() + if err != nil { + lastStmtError = err + } + } + + // can't drop if not all database users are dropped + if rows.Err() != nil { + return fmt.Errorf("cound not generate sql statements for all rows: %s", rows.Err()) + } + if lastStmtError != nil { + return fmt.Errorf("could not perform all sql statements: %s", lastStmtError) + } + + // Drop this login + stmt, err = db.Prepare(fmt.Sprintf(dropLoginSQL, username, username)) + if err != nil { + return err + } + defer stmt.Close() + if _, err := stmt.Exec(); err != nil { + return err + } + + return nil +} + +const dropUserSQL = ` +USE [%s] +IF EXISTS + (SELECT name + FROM sys.database_principals + WHERE name = N'%s') +BEGIN + DROP USER [%s] +END +` + +const dropLoginSQL = ` +IF EXISTS + (SELECT name + FROM master.sys.server_principals + WHERE name = N'%s') +BEGIN + DROP LOGIN [%s] +END +` diff --git a/plugins/database/mssql/mssql_test.go b/plugins/database/mssql/mssql_test.go new file mode 100644 index 000000000000..bc182f26fd4c --- /dev/null +++ b/plugins/database/mssql/mssql_test.go @@ -0,0 +1,173 @@ +package mssql + +import ( + "database/sql" + "fmt" + "os" + "strings" + "sync" + "testing" + "time" + + "github.com/hashicorp/vault/builtin/logical/database/dbplugin" + "github.com/hashicorp/vault/plugins/helper/database/connutil" + dockertest "gopkg.in/ory-am/dockertest.v3" +) + +var ( + testMSQLImagePull sync.Once +) + +func prepareMSSQLTestContainer(t *testing.T) (cleanup func(), retURL string) { + if os.Getenv("MSSQL_URL") != "" { + return func() {}, os.Getenv("MSSQL_URL") + } + + pool, err := dockertest.NewPool("") + if err != nil { + t.Fatalf("Failed to connect to docker: %s", err) + } + + resource, err := pool.Run("microsoft/mssql-server-linux", "latest", []string{"ACCEPT_EULA=Y", "SA_PASSWORD=yourStrong(!)Password"}) + if err != nil { + t.Fatalf("Could not start local MSSQL docker container: %s", err) + } + + cleanup = func() { + err := pool.Purge(resource) + if err != nil { + t.Fatalf("Failed to cleanup local container: %s", err) + } + } + + retURL = fmt.Sprintf("sqlserver://sa:yourStrong(!)Password@localhost:%s", resource.GetPort("1433/tcp")) + + // exponential backoff-retry + if err = pool.Retry(func() error { + var err error + var db *sql.DB + db, err = sql.Open("mssql", retURL) + if err != nil { + return err + } + return db.Ping() + }); err != nil { + t.Fatalf("Could not connect to MSSQL docker container: %s", err) + } + + return +} + +func TestMSSQL_Initialize(t *testing.T) { + cleanup, connURL := prepareMSSQLTestContainer(t) + defer cleanup() + + connectionDetails := map[string]interface{}{ + "connection_url": connURL, + } + + db := New() + + err := db.Initialize(connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer) + if !connProducer.Initialized { + t.Fatal("Database should be initalized") + } + + err = db.Close() + if err != nil { + t.Fatalf("err: %s", err) + } +} + +func TestMSSQL_CreateUser(t *testing.T) { + cleanup, connURL := prepareMSSQLTestContainer(t) + defer cleanup() + + connectionDetails := map[string]interface{}{ + "connection_url": connURL, + } + + db := New() + err := db.Initialize(connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Test with no configured Creation Statememt + _, _, err = db.CreateUser(dbplugin.Statements{}, "test", time.Now().Add(time.Minute)) + if err == nil { + t.Fatal("Expected error when no creation statement is provided") + } + + statements := dbplugin.Statements{ + CreationStatements: testMSSQLRole, + } + + username, password, err := db.CreateUser(statements, "test", time.Now().Add(time.Minute)) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err = testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } +} + +func TestMSSQL_RevokeUser(t *testing.T) { + cleanup, connURL := prepareMSSQLTestContainer(t) + defer cleanup() + + connectionDetails := map[string]interface{}{ + "connection_url": connURL, + } + + db := New() + err := db.Initialize(connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + statements := dbplugin.Statements{ + CreationStatements: testMSSQLRole, + } + + username, password, err := db.CreateUser(statements, "test", time.Now().Add(2*time.Second)) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err = testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } + + // Test default revoke statememts + err = db.RevokeUser(statements, username) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err := testCredsExist(t, connURL, username, password); err == nil { + t.Fatal("Credentials were not revoked") + } +} + +func testCredsExist(t testing.TB, connURL, username, password string) error { + // Log in with the new creds + connURL = strings.Replace(connURL, "sa:yourStrong(!)Password", fmt.Sprintf("%s:%s", username, password), 1) + db, err := sql.Open("mssql", connURL) + if err != nil { + return err + } + defer db.Close() + return db.Ping() +} + +const testMSSQLRole = ` +CREATE LOGIN [{{name}}] WITH PASSWORD = '{{password}}'; +CREATE USER [{{name}}] FOR LOGIN [{{name}}]; +GRANT SELECT, INSERT, UPDATE, DELETE ON SCHEMA::dbo TO [{{name}}];` diff --git a/plugins/database/mysql/mysql-database-plugin/main.go b/plugins/database/mysql/mysql-database-plugin/main.go new file mode 100644 index 000000000000..c0ec75c9cdc0 --- /dev/null +++ b/plugins/database/mysql/mysql-database-plugin/main.go @@ -0,0 +1,16 @@ +package main + +import ( + "fmt" + "os" + + "github.com/hashicorp/vault/plugins/database/mysql" +) + +func main() { + err := mysql.Run() + if err != nil { + fmt.Println(err) + os.Exit(1) + } +} diff --git a/plugins/database/mysql/mysql.go b/plugins/database/mysql/mysql.go new file mode 100644 index 000000000000..ea14a6782b84 --- /dev/null +++ b/plugins/database/mysql/mysql.go @@ -0,0 +1,183 @@ +package mysql + +import ( + "database/sql" + "strings" + "time" + + "github.com/hashicorp/vault/builtin/logical/database/dbplugin" + "github.com/hashicorp/vault/helper/strutil" + "github.com/hashicorp/vault/plugins/helper/database/connutil" + "github.com/hashicorp/vault/plugins/helper/database/credsutil" + "github.com/hashicorp/vault/plugins/helper/database/dbutil" +) + +const defaultMysqlRevocationStmts = ` + REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'%'; + DROP USER '{{name}}'@'%' +` +const mySQLTypeName = "mysql" + +type MySQL struct { + connutil.ConnectionProducer + credsutil.CredentialsProducer +} + +func New() *MySQL { + connProducer := &connutil.SQLConnectionProducer{} + connProducer.Type = mySQLTypeName + + credsProducer := &credsutil.SQLCredentialsProducer{ + DisplayNameLen: 4, + UsernameLen: 16, + } + + dbType := &MySQL{ + ConnectionProducer: connProducer, + CredentialsProducer: credsProducer, + } + + return dbType +} + +// Run instantiates a MySQL object, and runs the RPC server for the plugin +func Run() error { + dbType := New() + + dbplugin.NewPluginServer(dbType) + + return nil +} + +func (m *MySQL) Type() (string, error) { + return mySQLTypeName, nil +} + +func (m *MySQL) getConnection() (*sql.DB, error) { + db, err := m.Connection() + if err != nil { + return nil, err + } + + return db.(*sql.DB), nil +} + +func (m *MySQL) CreateUser(statements dbplugin.Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) { + // Grab the lock + m.Lock() + defer m.Unlock() + + // Get the connection + db, err := m.getConnection() + if err != nil { + return "", "", err + } + + if statements.CreationStatements == "" { + return "", "", dbutil.ErrEmptyCreationStatement + } + + username, err = m.GenerateUsername(usernamePrefix) + if err != nil { + return "", "", err + } + + password, err = m.GeneratePassword() + if err != nil { + return "", "", err + } + + expirationStr, err := m.GenerateExpiration(expiration) + if err != nil { + return "", "", err + } + + // Start a transaction + tx, err := db.Begin() + if err != nil { + return "", "", err + } + defer tx.Rollback() + + // Execute each query + for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{ + "name": username, + "password": password, + "expiration": expirationStr, + })) + if err != nil { + return "", "", err + } + defer stmt.Close() + if _, err := stmt.Exec(); err != nil { + return "", "", err + } + } + + // Commit the transaction + if err := tx.Commit(); err != nil { + return "", "", err + } + + return username, password, nil +} + +// NOOP +func (m *MySQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { + return nil +} + +func (m *MySQL) RevokeUser(statements dbplugin.Statements, username string) error { + // Grab the read lock + m.Lock() + defer m.Unlock() + + // Get the connection + db, err := m.getConnection() + if err != nil { + return err + } + + revocationStmts := statements.RevocationStatements + // Use a default SQL statement for revocation if one cannot be fetched from the role + if revocationStmts == "" { + revocationStmts = defaultMysqlRevocationStmts + } + + // Start a transaction + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + for _, query := range strutil.ParseArbitraryStringSlice(revocationStmts, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + // This is not a prepared statement because not all commands are supported + // 1295: This command is not supported in the prepared statement protocol yet + // Reference https://mariadb.com/kb/en/mariadb/prepare-statement/ + query = strings.Replace(query, "{{name}}", username, -1) + _, err = tx.Exec(query) + if err != nil { + return err + } + + } + + // Commit the transaction + if err := tx.Commit(); err != nil { + return err + } + + return nil +} diff --git a/plugins/database/mysql/mysql_test.go b/plugins/database/mysql/mysql_test.go new file mode 100644 index 000000000000..2b1f27291861 --- /dev/null +++ b/plugins/database/mysql/mysql_test.go @@ -0,0 +1,200 @@ +package mysql + +import ( + "database/sql" + "fmt" + "os" + "strings" + "sync" + "testing" + "time" + + "github.com/hashicorp/vault/builtin/logical/database/dbplugin" + "github.com/hashicorp/vault/plugins/helper/database/connutil" + dockertest "gopkg.in/ory-am/dockertest.v3" +) + +var ( + testMySQLImagePull sync.Once +) + +func prepareMySQLTestContainer(t *testing.T) (cleanup func(), retURL string) { + if os.Getenv("MYSQL_URL") != "" { + return func() {}, os.Getenv("MYSQL_URL") + } + + pool, err := dockertest.NewPool("") + if err != nil { + t.Fatalf("Failed to connect to docker: %s", err) + } + + resource, err := pool.Run("mysql", "latest", []string{"MYSQL_ROOT_PASSWORD=secret"}) + if err != nil { + t.Fatalf("Could not start local MySQL docker container: %s", err) + } + + cleanup = func() { + err := pool.Purge(resource) + if err != nil { + t.Fatalf("Failed to cleanup local container: %s", err) + } + } + + retURL = fmt.Sprintf("root:secret@(localhost:%s)/mysql?parseTime=true", resource.GetPort("3306/tcp")) + + // exponential backoff-retry + if err = pool.Retry(func() error { + var err error + var db *sql.DB + db, err = sql.Open("mysql", retURL) + if err != nil { + return err + } + return db.Ping() + }); err != nil { + t.Fatalf("Could not connect to MySQL docker container: %s", err) + } + + return +} + +func TestMySQL_Initialize(t *testing.T) { + cleanup, connURL := prepareMySQLTestContainer(t) + defer cleanup() + + connectionDetails := map[string]interface{}{ + "connection_url": connURL, + } + + db := New() + connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer) + + err := db.Initialize(connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + if !connProducer.Initialized { + t.Fatal("Database should be initalized") + } + + err = db.Close() + if err != nil { + t.Fatalf("err: %s", err) + } +} + +func TestMySQL_CreateUser(t *testing.T) { + cleanup, connURL := prepareMySQLTestContainer(t) + defer cleanup() + + connectionDetails := map[string]interface{}{ + "connection_url": connURL, + } + + db := New() + + err := db.Initialize(connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Test with no configured Creation Statememt + _, _, err = db.CreateUser(dbplugin.Statements{}, "test", time.Now().Add(time.Minute)) + if err == nil { + t.Fatal("Expected error when no creation statement is provided") + } + + statements := dbplugin.Statements{ + CreationStatements: testMySQLRoleWildCard, + } + + username, password, err := db.CreateUser(statements, "test", time.Now().Add(time.Minute)) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err := testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } +} + +func TestMySQL_RevokeUser(t *testing.T) { + cleanup, connURL := prepareMySQLTestContainer(t) + defer cleanup() + + connectionDetails := map[string]interface{}{ + "connection_url": connURL, + } + + db := New() + + err := db.Initialize(connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + statements := dbplugin.Statements{ + CreationStatements: testMySQLRoleWildCard, + } + + username, password, err := db.CreateUser(statements, "test", time.Now().Add(time.Minute)) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err := testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } + + // Test default revoke statememts + err = db.RevokeUser(statements, username) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err := testCredsExist(t, connURL, username, password); err == nil { + t.Fatal("Credentials were not revoked") + } + + statements.CreationStatements = testMySQLRoleWildCard + username, password, err = db.CreateUser(statements, "test", time.Now().Add(time.Minute)) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err := testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } + + // Test custom revoke statements + statements.RevocationStatements = testMySQLRevocationSQL + err = db.RevokeUser(statements, username) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err := testCredsExist(t, connURL, username, password); err == nil { + t.Fatal("Credentials were not revoked") + } +} + +func testCredsExist(t testing.TB, connURL, username, password string) error { + // Log in with the new creds + connURL = strings.Replace(connURL, "root:secret", fmt.Sprintf("%s:%s", username, password), 1) + db, err := sql.Open("mysql", connURL) + if err != nil { + return err + } + defer db.Close() + return db.Ping() +} + +const testMySQLRoleWildCard = ` +CREATE USER '{{name}}'@'%' IDENTIFIED BY '{{password}}'; +GRANT SELECT ON *.* TO '{{name}}'@'%'; +` +const testMySQLRevocationSQL = ` +REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'%'; +DROP USER '{{name}}'@'%'; +` diff --git a/plugins/database/postgresql/postgresql-database-plugin/main.go b/plugins/database/postgresql/postgresql-database-plugin/main.go new file mode 100644 index 000000000000..9b9b813c4c19 --- /dev/null +++ b/plugins/database/postgresql/postgresql-database-plugin/main.go @@ -0,0 +1,16 @@ +package main + +import ( + "fmt" + "os" + + "github.com/hashicorp/vault/plugins/database/postgresql" +) + +func main() { + err := postgresql.Run() + if err != nil { + fmt.Println(err) + os.Exit(1) + } +} diff --git a/plugins/database/postgresql/postgresql.go b/plugins/database/postgresql/postgresql.go new file mode 100644 index 000000000000..b8449f54986c --- /dev/null +++ b/plugins/database/postgresql/postgresql.go @@ -0,0 +1,337 @@ +package postgresql + +import ( + "database/sql" + "fmt" + "strings" + "time" + + "github.com/hashicorp/vault/builtin/logical/database/dbplugin" + "github.com/hashicorp/vault/helper/strutil" + "github.com/hashicorp/vault/plugins/helper/database/connutil" + "github.com/hashicorp/vault/plugins/helper/database/credsutil" + "github.com/hashicorp/vault/plugins/helper/database/dbutil" + "github.com/lib/pq" +) + +const postgreSQLTypeName string = "postgres" + +func New() *PostgreSQL { + connProducer := &connutil.SQLConnectionProducer{} + connProducer.Type = postgreSQLTypeName + + credsProducer := &credsutil.SQLCredentialsProducer{ + DisplayNameLen: 4, + UsernameLen: 16, + } + + dbType := &PostgreSQL{ + ConnectionProducer: connProducer, + CredentialsProducer: credsProducer, + } + + return dbType +} + +// Run instatiates a PostgreSQL object, and runs the RPC server for the plugin +func Run() error { + dbType := New() + + dbplugin.NewPluginServer(dbType) + + return nil +} + +type PostgreSQL struct { + connutil.ConnectionProducer + credsutil.CredentialsProducer +} + +func (p *PostgreSQL) Type() (string, error) { + return postgreSQLTypeName, nil +} + +func (p *PostgreSQL) getConnection() (*sql.DB, error) { + db, err := p.Connection() + if err != nil { + return nil, err + } + + return db.(*sql.DB), nil +} + +func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) { + if statements.CreationStatements == "" { + return "", "", dbutil.ErrEmptyCreationStatement + } + + // Grab the lock + p.Lock() + defer p.Unlock() + + username, err = p.GenerateUsername(usernamePrefix) + if err != nil { + return "", "", err + } + + password, err = p.GeneratePassword() + if err != nil { + return "", "", err + } + + expirationStr, err := p.GenerateExpiration(expiration) + if err != nil { + return "", "", err + } + + // Get the connection + db, err := p.getConnection() + if err != nil { + return "", "", err + + } + + // Start a transaction + tx, err := db.Begin() + if err != nil { + return "", "", err + + } + defer func() { + tx.Rollback() + }() + // Return the secret + + // Execute each query + for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{ + "name": username, + "password": password, + "expiration": expirationStr, + })) + if err != nil { + return "", "", err + + } + defer stmt.Close() + if _, err := stmt.Exec(); err != nil { + return "", "", err + + } + } + + // Commit the transaction + if err := tx.Commit(); err != nil { + return "", "", err + + } + + return username, password, nil +} + +func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { + // Grab the lock + p.Lock() + defer p.Unlock() + + db, err := p.getConnection() + if err != nil { + return err + } + + expirationStr, err := p.GenerateExpiration(expiration) + if err != nil { + return err + } + + query := fmt.Sprintf( + "ALTER ROLE %s VALID UNTIL '%s';", + pq.QuoteIdentifier(username), + expirationStr) + + stmt, err := db.Prepare(query) + if err != nil { + return err + } + defer stmt.Close() + if _, err := stmt.Exec(); err != nil { + return err + } + + return nil +} + +func (p *PostgreSQL) RevokeUser(statements dbplugin.Statements, username string) error { + // Grab the lock + p.Lock() + defer p.Unlock() + + if statements.RevocationStatements == "" { + return p.defaultRevokeUser(username) + } + + return p.customRevokeUser(username, statements.RevocationStatements) +} + +func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error { + db, err := p.getConnection() + if err != nil { + return err + } + + tx, err := db.Begin() + if err != nil { + return err + } + defer func() { + tx.Rollback() + }() + + for _, query := range strutil.ParseArbitraryStringSlice(revocationStmts, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{ + "name": username, + })) + if err != nil { + return err + } + defer stmt.Close() + + if _, err := stmt.Exec(); err != nil { + return err + } + } + + if err := tx.Commit(); err != nil { + return err + } + + return nil +} + +func (p *PostgreSQL) defaultRevokeUser(username string) error { + db, err := p.getConnection() + if err != nil { + return err + } + + // Check if the role exists + var exists bool + err = db.QueryRow("SELECT exists (SELECT rolname FROM pg_roles WHERE rolname=$1);", username).Scan(&exists) + if err != nil && err != sql.ErrNoRows { + return err + } + + if exists == false { + return nil + } + + // Query for permissions; we need to revoke permissions before we can drop + // the role + // This isn't done in a transaction because even if we fail along the way, + // we want to remove as much access as possible + stmt, err := db.Prepare("SELECT DISTINCT table_schema FROM information_schema.role_column_grants WHERE grantee=$1;") + if err != nil { + return err + } + defer stmt.Close() + + rows, err := stmt.Query(username) + if err != nil { + return err + } + defer rows.Close() + + const initialNumRevocations = 16 + revocationStmts := make([]string, 0, initialNumRevocations) + for rows.Next() { + var schema string + err = rows.Scan(&schema) + if err != nil { + // keep going; remove as many permissions as possible right now + continue + } + revocationStmts = append(revocationStmts, fmt.Sprintf( + `REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA %s FROM %s;`, + pq.QuoteIdentifier(schema), + pq.QuoteIdentifier(username))) + + revocationStmts = append(revocationStmts, fmt.Sprintf( + `REVOKE USAGE ON SCHEMA %s FROM %s;`, + pq.QuoteIdentifier(schema), + pq.QuoteIdentifier(username))) + } + + // for good measure, revoke all privileges and usage on schema public + revocationStmts = append(revocationStmts, fmt.Sprintf( + `REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM %s;`, + pq.QuoteIdentifier(username))) + + revocationStmts = append(revocationStmts, fmt.Sprintf( + "REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM %s;", + pq.QuoteIdentifier(username))) + + revocationStmts = append(revocationStmts, fmt.Sprintf( + "REVOKE USAGE ON SCHEMA public FROM %s;", + pq.QuoteIdentifier(username))) + + // get the current database name so we can issue a REVOKE CONNECT for + // this username + var dbname sql.NullString + if err := db.QueryRow("SELECT current_database();").Scan(&dbname); err != nil { + return err + } + + if dbname.Valid { + revocationStmts = append(revocationStmts, fmt.Sprintf( + `REVOKE CONNECT ON DATABASE %s FROM %s;`, + pq.QuoteIdentifier(dbname.String), + pq.QuoteIdentifier(username))) + } + + // again, here, we do not stop on error, as we want to remove as + // many permissions as possible right now + var lastStmtError error + for _, query := range revocationStmts { + stmt, err := db.Prepare(query) + if err != nil { + lastStmtError = err + continue + } + defer stmt.Close() + _, err = stmt.Exec() + if err != nil { + lastStmtError = err + } + } + + // can't drop if not all privileges are revoked + if rows.Err() != nil { + return fmt.Errorf("could not generate revocation statements for all rows: %s", rows.Err()) + } + if lastStmtError != nil { + return fmt.Errorf("could not perform all revocation statements: %s", lastStmtError) + } + + // Drop this user + stmt, err = db.Prepare(fmt.Sprintf( + `DROP ROLE IF EXISTS %s;`, pq.QuoteIdentifier(username))) + if err != nil { + return err + } + defer stmt.Close() + if _, err := stmt.Exec(); err != nil { + return err + } + + return nil +} diff --git a/plugins/database/postgresql/postgresql_test.go b/plugins/database/postgresql/postgresql_test.go new file mode 100644 index 000000000000..c7ccc8ee8f13 --- /dev/null +++ b/plugins/database/postgresql/postgresql_test.go @@ -0,0 +1,308 @@ +package postgresql + +import ( + "database/sql" + "fmt" + "os" + "strings" + "sync" + "testing" + "time" + + "github.com/hashicorp/vault/builtin/logical/database/dbplugin" + "github.com/hashicorp/vault/plugins/helper/database/connutil" + dockertest "gopkg.in/ory-am/dockertest.v3" +) + +var ( + testPostgresImagePull sync.Once +) + +func preparePostgresTestContainer(t *testing.T) (cleanup func(), retURL string) { + if os.Getenv("PG_URL") != "" { + return func() {}, os.Getenv("PG_URL") + } + + pool, err := dockertest.NewPool("") + if err != nil { + t.Fatalf("Failed to connect to docker: %s", err) + } + + resource, err := pool.Run("postgres", "latest", []string{"POSTGRES_PASSWORD=secret", "POSTGRES_DB=database"}) + if err != nil { + t.Fatalf("Could not start local PostgreSQL docker container: %s", err) + } + + cleanup = func() { + err := pool.Purge(resource) + if err != nil { + t.Fatalf("Failed to cleanup local container: %s", err) + } + } + + retURL = fmt.Sprintf("postgres://postgres:secret@localhost:%s/database?sslmode=disable", resource.GetPort("5432/tcp")) + + // exponential backoff-retry + if err = pool.Retry(func() error { + var err error + var db *sql.DB + db, err = sql.Open("postgres", retURL) + if err != nil { + return err + } + return db.Ping() + }); err != nil { + t.Fatalf("Could not connect to PostgreSQL docker container: %s", err) + } + + return +} + +func TestPostgreSQL_Initialize(t *testing.T) { + cleanup, connURL := preparePostgresTestContainer(t) + defer cleanup() + + connectionDetails := map[string]interface{}{ + "connection_url": connURL, + } + + db := New() + connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer) + + err := db.Initialize(connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + if !connProducer.Initialized { + t.Fatal("Database should be initalized") + } + + err = db.Close() + if err != nil { + t.Fatalf("err: %s", err) + } +} + +func TestPostgreSQL_CreateUser(t *testing.T) { + cleanup, connURL := preparePostgresTestContainer(t) + defer cleanup() + + connectionDetails := map[string]interface{}{ + "connection_url": connURL, + } + + db := New() + err := db.Initialize(connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Test with no configured Creation Statememt + _, _, err = db.CreateUser(dbplugin.Statements{}, "test", time.Now().Add(time.Minute)) + if err == nil { + t.Fatal("Expected error when no creation statement is provided") + } + + statements := dbplugin.Statements{ + CreationStatements: testPostgresRole, + } + + username, password, err := db.CreateUser(statements, "test", time.Now().Add(time.Minute)) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err = testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } + + statements.CreationStatements = testPostgresReadOnlyRole + username, password, err = db.CreateUser(statements, "test", time.Now().Add(time.Minute)) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err = testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } +} + +func TestPostgreSQL_RenewUser(t *testing.T) { + cleanup, connURL := preparePostgresTestContainer(t) + defer cleanup() + + connectionDetails := map[string]interface{}{ + "connection_url": connURL, + } + + db := New() + err := db.Initialize(connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + statements := dbplugin.Statements{ + CreationStatements: testPostgresRole, + } + + username, password, err := db.CreateUser(statements, "test", time.Now().Add(2*time.Second)) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err = testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } + + err = db.RenewUser(statements, username, time.Now().Add(time.Minute)) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Sleep longer than the inital expiration time + time.Sleep(2 * time.Second) + + if err = testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } +} + +func TestPostgreSQL_RevokeUser(t *testing.T) { + cleanup, connURL := preparePostgresTestContainer(t) + defer cleanup() + + connectionDetails := map[string]interface{}{ + "connection_url": connURL, + } + + db := New() + err := db.Initialize(connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + statements := dbplugin.Statements{ + CreationStatements: testPostgresRole, + } + + username, password, err := db.CreateUser(statements, "test", time.Now().Add(2*time.Second)) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err = testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } + + // Test default revoke statememts + err = db.RevokeUser(statements, username) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err := testCredsExist(t, connURL, username, password); err == nil { + t.Fatal("Credentials were not revoked") + } + + username, password, err = db.CreateUser(statements, "test", time.Now().Add(2*time.Second)) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err = testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } + + // Test custom revoke statements + statements.RevocationStatements = defaultPostgresRevocationSQL + err = db.RevokeUser(statements, username) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err := testCredsExist(t, connURL, username, password); err == nil { + t.Fatal("Credentials were not revoked") + } +} + +func testCredsExist(t testing.TB, connURL, username, password string) error { + // Log in with the new creds + connURL = strings.Replace(connURL, "postgres:secret", fmt.Sprintf("%s:%s", username, password), 1) + db, err := sql.Open("postgres", connURL) + if err != nil { + return err + } + defer db.Close() + return db.Ping() +} + +const testPostgresRole = ` +CREATE ROLE "{{name}}" WITH + LOGIN + PASSWORD '{{password}}' + VALID UNTIL '{{expiration}}'; +GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}"; +` + +const testPostgresReadOnlyRole = ` +CREATE ROLE "{{name}}" WITH + LOGIN + PASSWORD '{{password}}' + VALID UNTIL '{{expiration}}'; +GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}"; +GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}"; +` + +const testPostgresBlockStatementRole = ` +DO $$ +BEGIN + IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN + CREATE ROLE "foo-role"; + CREATE SCHEMA IF NOT EXISTS foo AUTHORIZATION "foo-role"; + ALTER ROLE "foo-role" SET search_path = foo; + GRANT TEMPORARY ON DATABASE "postgres" TO "foo-role"; + GRANT ALL PRIVILEGES ON SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA foo TO "foo-role"; + END IF; +END +$$ + +CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}'; +GRANT "foo-role" TO "{{name}}"; +ALTER ROLE "{{name}}" SET search_path = foo; +GRANT CONNECT ON DATABASE "postgres" TO "{{name}}"; +` + +var testPostgresBlockStatementRoleSlice = []string{ + ` +DO $$ +BEGIN + IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN + CREATE ROLE "foo-role"; + CREATE SCHEMA IF NOT EXISTS foo AUTHORIZATION "foo-role"; + ALTER ROLE "foo-role" SET search_path = foo; + GRANT TEMPORARY ON DATABASE "postgres" TO "foo-role"; + GRANT ALL PRIVILEGES ON SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA foo TO "foo-role"; + END IF; +END +$$ +`, + `CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}';`, + `GRANT "foo-role" TO "{{name}}";`, + `ALTER ROLE "{{name}}" SET search_path = foo;`, + `GRANT CONNECT ON DATABASE "postgres" TO "{{name}}";`, +} + +const defaultPostgresRevocationSQL = ` +REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM "{{name}}"; +REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM "{{name}}"; +REVOKE USAGE ON SCHEMA public FROM "{{name}}"; + +DROP ROLE IF EXISTS "{{name}}"; +` diff --git a/plugins/helper/database/connutil/cassandra.go b/plugins/helper/database/connutil/cassandra.go new file mode 100644 index 000000000000..305bc6e3d072 --- /dev/null +++ b/plugins/helper/database/connutil/cassandra.go @@ -0,0 +1,172 @@ +package connutil + +import ( + "crypto/tls" + "fmt" + "strings" + "sync" + "time" + + "github.com/mitchellh/mapstructure" + + "github.com/gocql/gocql" + "github.com/hashicorp/vault/helper/certutil" + "github.com/hashicorp/vault/helper/tlsutil" +) + +// CassandraConnectionProducer implements ConnectionProducer and provides an +// interface for cassandra databases to make connections. +type CassandraConnectionProducer struct { + Hosts string `json:"hosts" structs:"hosts" mapstructure:"hosts"` + Username string `json:"username" structs:"username" mapstructure:"username"` + Password string `json:"password" structs:"password" mapstructure:"password"` + TLS bool `json:"tls" structs:"tls" mapstructure:"tls"` + InsecureTLS bool `json:"insecure_tls" structs:"insecure_tls" mapstructure:"insecure_tls"` + Certificate string `json:"certificate" structs:"certificate" mapstructure:"certificate"` + PrivateKey string `json:"private_key" structs:"private_key" mapstructure:"private_key"` + IssuingCA string `json:"issuing_ca" structs:"issuing_ca" mapstructure:"issuing_ca"` + ProtocolVersion int `json:"protocol_version" structs:"protocol_version" mapstructure:"protocol_version"` + ConnectTimeout int `json:"connect_timeout" structs:"connect_timeout" mapstructure:"connect_timeout"` + TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"` + Consistency string `json:"consistency" structs:"consistency" mapstructure:"consistency"` + + Initialized bool + session *gocql.Session + sync.Mutex +} + +func (c *CassandraConnectionProducer) Initialize(conf map[string]interface{}, verifyConnection bool) error { + c.Lock() + defer c.Unlock() + + err := mapstructure.Decode(conf, c) + if err != nil { + return err + } + c.Initialized = true + + if verifyConnection { + if _, err := c.connection(); err != nil { + return fmt.Errorf("error Initalizing Connection: %s", err) + } + } + return nil +} + +func (c *CassandraConnectionProducer) connection() (interface{}, error) { + if !c.Initialized { + return nil, errNotInitialized + } + + // If we already have a DB, return it + if c.session != nil { + return c.session, nil + } + + session, err := c.createSession() + if err != nil { + return nil, err + } + + // Store the session in backend for reuse + c.session = session + + return session, nil +} + +func (c *CassandraConnectionProducer) Close() error { + // Grab the write lock + c.Lock() + defer c.Unlock() + + if c.session != nil { + c.session.Close() + } + + c.session = nil + + return nil +} + +func (c *CassandraConnectionProducer) createSession() (*gocql.Session, error) { + clusterConfig := gocql.NewCluster(strings.Split(c.Hosts, ",")...) + clusterConfig.Authenticator = gocql.PasswordAuthenticator{ + Username: c.Username, + Password: c.Password, + } + + clusterConfig.ProtoVersion = c.ProtocolVersion + if clusterConfig.ProtoVersion == 0 { + clusterConfig.ProtoVersion = 2 + } + + clusterConfig.Timeout = time.Duration(c.ConnectTimeout) * time.Second + + if c.TLS { + var tlsConfig *tls.Config + if len(c.Certificate) > 0 || len(c.IssuingCA) > 0 { + if len(c.Certificate) > 0 && len(c.PrivateKey) == 0 { + return nil, fmt.Errorf("Found certificate for TLS authentication but no private key") + } + + certBundle := &certutil.CertBundle{} + if len(c.Certificate) > 0 { + certBundle.Certificate = c.Certificate + certBundle.PrivateKey = c.PrivateKey + } + if len(c.IssuingCA) > 0 { + certBundle.IssuingCA = c.IssuingCA + } + + parsedCertBundle, err := certBundle.ToParsedCertBundle() + if err != nil { + return nil, fmt.Errorf("failed to parse certificate bundle: %s", err) + } + + tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient) + if err != nil || tlsConfig == nil { + return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%v", tlsConfig, err) + } + tlsConfig.InsecureSkipVerify = c.InsecureTLS + + if c.TLSMinVersion != "" { + var ok bool + tlsConfig.MinVersion, ok = tlsutil.TLSLookup[c.TLSMinVersion] + if !ok { + return nil, fmt.Errorf("invalid 'tls_min_version' in config") + } + } else { + // MinVersion was not being set earlier. Reset it to + // zero to gracefully handle upgrades. + tlsConfig.MinVersion = 0 + } + } + + clusterConfig.SslOpts = &gocql.SslOptions{ + Config: *tlsConfig, + } + } + + session, err := clusterConfig.CreateSession() + if err != nil { + return nil, fmt.Errorf("error creating session: %s", err) + } + + // Set consistency + if c.Consistency != "" { + consistencyValue, err := gocql.ParseConsistencyWrapper(c.Consistency) + if err != nil { + return nil, err + } + + session.SetConsistency(consistencyValue) + } + + // Verify the info + err = session.Query(`LIST USERS`).Exec() + if err != nil { + return nil, fmt.Errorf("error validating connection info: %s", err) + } + + return session, nil +} diff --git a/plugins/helper/database/connutil/connutil.go b/plugins/helper/database/connutil/connutil.go new file mode 100644 index 000000000000..6de3299e3899 --- /dev/null +++ b/plugins/helper/database/connutil/connutil.go @@ -0,0 +1,21 @@ +package connutil + +import ( + "errors" + "sync" +) + +var ( + errNotInitialized = errors.New("connection has not been initalized") +) + +// ConnectionProducer can be used as an embeded interface in the DatabaseType +// definition. It implements the methods dealing with individual database +// connections and is used in all the builtin database types. +type ConnectionProducer interface { + Close() error + Initialize(map[string]interface{}, bool) error + Connection() (interface{}, error) + + sync.Locker +} diff --git a/plugins/helper/database/connutil/sql.go b/plugins/helper/database/connutil/sql.go new file mode 100644 index 000000000000..0bfc5f9f684f --- /dev/null +++ b/plugins/helper/database/connutil/sql.go @@ -0,0 +1,131 @@ +package connutil + +import ( + "database/sql" + "fmt" + "strings" + "sync" + "time" + + // Import sql drivers + _ "github.com/denisenkom/go-mssqldb" + _ "github.com/go-sql-driver/mysql" + _ "github.com/lib/pq" + "github.com/mitchellh/mapstructure" +) + +// SQLConnectionProducer implements ConnectionProducer and provides a generic producer for most sql databases +type SQLConnectionProducer struct { + ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"` + MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"` + MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"` + MaxConnectionLifetimeRaw string `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"` + + Type string + MaxConnectionLifetime time.Duration + Initialized bool + db *sql.DB + sync.Mutex +} + +func (c *SQLConnectionProducer) Initialize(conf map[string]interface{}, verifyConnection bool) error { + c.Lock() + defer c.Unlock() + + err := mapstructure.Decode(conf, c) + if err != nil { + return err + } + + if c.MaxOpenConnections == 0 { + c.MaxOpenConnections = 2 + } + + if c.MaxIdleConnections == 0 { + c.MaxIdleConnections = c.MaxOpenConnections + } + if c.MaxIdleConnections > c.MaxOpenConnections { + c.MaxIdleConnections = c.MaxOpenConnections + } + if c.MaxConnectionLifetimeRaw == "" { + c.MaxConnectionLifetimeRaw = "0s" + } + + c.MaxConnectionLifetime, err = time.ParseDuration(c.MaxConnectionLifetimeRaw) + if err != nil { + return fmt.Errorf("invalid max_connection_lifetime: %s", err) + } + + if verifyConnection { + if _, err := c.Connection(); err != nil { + return fmt.Errorf("error initalizing connection: %s", err) + } + + if err := c.db.Ping(); err != nil { + return fmt.Errorf("error initalizing connection: %s", err) + } + } + + c.Initialized = true + + return nil +} + +func (c *SQLConnectionProducer) Connection() (interface{}, error) { + // If we already have a DB, test it and return + if c.db != nil { + if err := c.db.Ping(); err == nil { + return c.db, nil + } + // If the ping was unsuccessful, close it and ignore errors as we'll be + // reestablishing anyways + c.db.Close() + } + + // For mssql backend, switch to sqlserver instead + dbType := c.Type + if c.Type == "mssql" { + dbType = "sqlserver" + } + + // Otherwise, attempt to make connection + conn := c.ConnectionURL + + // Ensure timezone is set to UTC for all the conenctions + if strings.HasPrefix(conn, "postgres://") || strings.HasPrefix(conn, "postgresql://") { + if strings.Contains(conn, "?") { + conn += "&timezone=utc" + } else { + conn += "?timezone=utc" + } + } + + var err error + c.db, err = sql.Open(dbType, conn) + if err != nil { + return nil, err + } + + // Set some connection pool settings. We don't need much of this, + // since the request rate shouldn't be high. + c.db.SetMaxOpenConns(c.MaxOpenConnections) + c.db.SetMaxIdleConns(c.MaxIdleConnections) + c.db.SetConnMaxLifetime(c.MaxConnectionLifetime) + + return c.db, nil +} + +// Close attempts to close the connection +func (c *SQLConnectionProducer) Close() error { + // Grab the write lock + c.Lock() + defer c.Unlock() + + if c.db != nil { + c.db.Close() + } + + c.db = nil + + return nil +} diff --git a/plugins/helper/database/credsutil/cassandra.go b/plugins/helper/database/credsutil/cassandra.go new file mode 100644 index 000000000000..7ab5630b5809 --- /dev/null +++ b/plugins/helper/database/credsutil/cassandra.go @@ -0,0 +1,37 @@ +package credsutil + +import ( + "fmt" + "strings" + "time" + + uuid "github.com/hashicorp/go-uuid" +) + +// CassandraCredentialsProducer implements CredentialsProducer and provides an +// interface for cassandra databases to generate user information. +type CassandraCredentialsProducer struct{} + +func (ccp *CassandraCredentialsProducer) GenerateUsername(displayName string) (string, error) { + userUUID, err := uuid.GenerateUUID() + if err != nil { + return "", err + } + username := fmt.Sprintf("vault_%s_%s_%d", displayName, userUUID, time.Now().Unix()) + username = strings.Replace(username, "-", "_", -1) + + return username, nil +} + +func (ccp *CassandraCredentialsProducer) GeneratePassword() (string, error) { + password, err := uuid.GenerateUUID() + if err != nil { + return "", err + } + + return password, nil +} + +func (ccp *CassandraCredentialsProducer) GenerateExpiration(ttl time.Time) (string, error) { + return "", nil +} diff --git a/plugins/helper/database/credsutil/credsutil.go b/plugins/helper/database/credsutil/credsutil.go new file mode 100644 index 000000000000..7f388a0f7678 --- /dev/null +++ b/plugins/helper/database/credsutil/credsutil.go @@ -0,0 +1,12 @@ +package credsutil + +import "time" + +// CredentialsProducer can be used as an embeded interface in the DatabaseType +// definition. It implements the methods for generating user information for a +// particular database type and is used in all the builtin database types. +type CredentialsProducer interface { + GenerateUsername(displayName string) (string, error) + GeneratePassword() (string, error) + GenerateExpiration(ttl time.Time) (string, error) +} diff --git a/plugins/helper/database/credsutil/sql.go b/plugins/helper/database/credsutil/sql.go new file mode 100644 index 000000000000..23e98102f3e8 --- /dev/null +++ b/plugins/helper/database/credsutil/sql.go @@ -0,0 +1,43 @@ +package credsutil + +import ( + "fmt" + "time" + + uuid "github.com/hashicorp/go-uuid" +) + +// SQLCredentialsProducer implements CredentialsProducer and provides a generic credentials producer for most sql database types. +type SQLCredentialsProducer struct { + DisplayNameLen int + UsernameLen int +} + +func (scp *SQLCredentialsProducer) GenerateUsername(displayName string) (string, error) { + if scp.DisplayNameLen > 0 && len(displayName) > scp.DisplayNameLen { + displayName = displayName[:scp.DisplayNameLen] + } + userUUID, err := uuid.GenerateUUID() + if err != nil { + return "", err + } + username := fmt.Sprintf("%s-%s", displayName, userUUID) + if scp.UsernameLen > 0 && len(username) > scp.UsernameLen { + username = username[:scp.UsernameLen] + } + + return username, nil +} + +func (scp *SQLCredentialsProducer) GeneratePassword() (string, error) { + password, err := uuid.GenerateUUID() + if err != nil { + return "", err + } + + return password, nil +} + +func (scp *SQLCredentialsProducer) GenerateExpiration(ttl time.Time) (string, error) { + return ttl.Format("2006-01-02 15:04:05-0700"), nil +} diff --git a/plugins/helper/database/dbutil/dbutil.go b/plugins/helper/database/dbutil/dbutil.go new file mode 100644 index 000000000000..e80273b7fb2c --- /dev/null +++ b/plugins/helper/database/dbutil/dbutil.go @@ -0,0 +1,20 @@ +package dbutil + +import ( + "errors" + "fmt" + "strings" +) + +var ( + ErrEmptyCreationStatement = errors.New("empty creation statements") +) + +// Query templates a query for us. +func QueryHelper(tpl string, data map[string]string) string { + for k, v := range data { + tpl = strings.Replace(tpl, fmt.Sprintf("{{%s}}", k), v, -1) + } + + return tpl +} From ea4173406e2804dd1a0dce071125eccd67ecb6dd Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 13 Apr 2017 14:30:15 -0700 Subject: [PATCH 079/152] Move mssql to be an acceptance test --- plugins/database/mssql/mssql_test.go | 62 +++++++--------------------- 1 file changed, 14 insertions(+), 48 deletions(-) diff --git a/plugins/database/mssql/mssql_test.go b/plugins/database/mssql/mssql_test.go index bc182f26fd4c..2bca0a7b85d6 100644 --- a/plugins/database/mssql/mssql_test.go +++ b/plugins/database/mssql/mssql_test.go @@ -11,56 +11,17 @@ import ( "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/plugins/helper/database/connutil" - dockertest "gopkg.in/ory-am/dockertest.v3" ) var ( testMSQLImagePull sync.Once ) -func prepareMSSQLTestContainer(t *testing.T) (cleanup func(), retURL string) { - if os.Getenv("MSSQL_URL") != "" { - return func() {}, os.Getenv("MSSQL_URL") - } - - pool, err := dockertest.NewPool("") - if err != nil { - t.Fatalf("Failed to connect to docker: %s", err) - } - - resource, err := pool.Run("microsoft/mssql-server-linux", "latest", []string{"ACCEPT_EULA=Y", "SA_PASSWORD=yourStrong(!)Password"}) - if err != nil { - t.Fatalf("Could not start local MSSQL docker container: %s", err) - } - - cleanup = func() { - err := pool.Purge(resource) - if err != nil { - t.Fatalf("Failed to cleanup local container: %s", err) - } - } - - retURL = fmt.Sprintf("sqlserver://sa:yourStrong(!)Password@localhost:%s", resource.GetPort("1433/tcp")) - - // exponential backoff-retry - if err = pool.Retry(func() error { - var err error - var db *sql.DB - db, err = sql.Open("mssql", retURL) - if err != nil { - return err - } - return db.Ping() - }); err != nil { - t.Fatalf("Could not connect to MSSQL docker container: %s", err) - } - - return -} - func TestMSSQL_Initialize(t *testing.T) { - cleanup, connURL := prepareMSSQLTestContainer(t) - defer cleanup() + if os.Getenv("MSSQL_URL") == "" { + return + } + connURL := os.Getenv("MSSQL_URL") connectionDetails := map[string]interface{}{ "connection_url": connURL, @@ -85,8 +46,10 @@ func TestMSSQL_Initialize(t *testing.T) { } func TestMSSQL_CreateUser(t *testing.T) { - cleanup, connURL := prepareMSSQLTestContainer(t) - defer cleanup() + if os.Getenv("MSSQL_URL") == "" { + return + } + connURL := os.Getenv("MSSQL_URL") connectionDetails := map[string]interface{}{ "connection_url": connURL, @@ -119,8 +82,10 @@ func TestMSSQL_CreateUser(t *testing.T) { } func TestMSSQL_RevokeUser(t *testing.T) { - cleanup, connURL := prepareMSSQLTestContainer(t) - defer cleanup() + if os.Getenv("MSSQL_URL") == "" { + return + } + connURL := os.Getenv("MSSQL_URL") connectionDetails := map[string]interface{}{ "connection_url": connURL, @@ -158,7 +123,8 @@ func TestMSSQL_RevokeUser(t *testing.T) { func testCredsExist(t testing.TB, connURL, username, password string) error { // Log in with the new creds - connURL = strings.Replace(connURL, "sa:yourStrong(!)Password", fmt.Sprintf("%s:%s", username, password), 1) + parts := strings.Split(connURL, "@") + connURL = fmt.Sprintf("sqlserver://%s:%s@%s", username, password, parts[1]) db, err := sql.Open("mssql", connURL) if err != nil { return err From 1f6bf2900aff25c52450212e9eeb33b930bea973 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 13 Apr 2017 14:40:59 -0700 Subject: [PATCH 080/152] Only run mssql acceptance test when running as VAULT_ACC=1 --- plugins/database/mssql/mssql_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/plugins/database/mssql/mssql_test.go b/plugins/database/mssql/mssql_test.go index 2bca0a7b85d6..512033bd76ee 100644 --- a/plugins/database/mssql/mssql_test.go +++ b/plugins/database/mssql/mssql_test.go @@ -18,7 +18,7 @@ var ( ) func TestMSSQL_Initialize(t *testing.T) { - if os.Getenv("MSSQL_URL") == "" { + if os.Getenv("MSSQL_URL") == "" || os.Getenv("VAULT_ACC") != "1" { return } connURL := os.Getenv("MSSQL_URL") @@ -46,7 +46,7 @@ func TestMSSQL_Initialize(t *testing.T) { } func TestMSSQL_CreateUser(t *testing.T) { - if os.Getenv("MSSQL_URL") == "" { + if os.Getenv("MSSQL_URL") == "" || os.Getenv("VAULT_ACC") != "1" { return } connURL := os.Getenv("MSSQL_URL") @@ -82,7 +82,7 @@ func TestMSSQL_CreateUser(t *testing.T) { } func TestMSSQL_RevokeUser(t *testing.T) { - if os.Getenv("MSSQL_URL") == "" { + if os.Getenv("MSSQL_URL") == "" || os.Getenv("VAULT_ACC") != "1" { return } connURL := os.Getenv("MSSQL_URL") From 370dd2d2f20d53b26d5e098e9e0cd3724d1b2fe0 Mon Sep 17 00:00:00 2001 From: Chris Hoffman Date: Tue, 18 Apr 2017 17:32:08 -0400 Subject: [PATCH 081/152] Adding explicit database to sp_msloginmappings call (#2611) --- plugins/database/mssql/mssql.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/database/mssql/mssql.go b/plugins/database/mssql/mssql.go index 567a095b664a..b0e0ab6d4198 100644 --- a/plugins/database/mssql/mssql.go +++ b/plugins/database/mssql/mssql.go @@ -187,7 +187,7 @@ func (m *MSSQL) RevokeUser(statements dbplugin.Statements, username string) erro // we need to drop the database users before we can drop the login and the role // This isn't done in a transaction because even if we fail along the way, // we want to remove as much access as possible - stmt, err := db.Prepare(fmt.Sprintf("EXEC sp_msloginmappings '%s';", username)) + stmt, err := db.Prepare(fmt.Sprintf("EXEC master.dbo.sp_msloginmappings '%s';", username)) if err != nil { return err } From 8b7fa73f9d4938186f0c2e1e04f1f117c1e63f8e Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 19 Apr 2017 11:19:29 -0700 Subject: [PATCH 082/152] Fix cassandra deps breakage --- plugins/helper/database/connutil/cassandra.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/helper/database/connutil/cassandra.go b/plugins/helper/database/connutil/cassandra.go index 305bc6e3d072..028c6814fb83 100644 --- a/plugins/helper/database/connutil/cassandra.go +++ b/plugins/helper/database/connutil/cassandra.go @@ -143,7 +143,7 @@ func (c *CassandraConnectionProducer) createSession() (*gocql.Session, error) { } clusterConfig.SslOpts = &gocql.SslOptions{ - Config: *tlsConfig, + Config: tlsConfig, } } From d9ce189b33198c83622d13e89b4ba935198e6236 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 19 Apr 2017 15:46:07 -0700 Subject: [PATCH 083/152] Use the same TLS cert for the server and client --- helper/pluginutil/runner.go | 6 +- helper/pluginutil/tls.go | 110 +++++++----------------------------- helper/strutil/strutil.go | 13 +++++ 3 files changed, 36 insertions(+), 93 deletions(-) diff --git a/helper/pluginutil/runner.go b/helper/pluginutil/runner.go index a57abad0edc5..bbc5ab99b752 100644 --- a/helper/pluginutil/runner.go +++ b/helper/pluginutil/runner.go @@ -51,20 +51,20 @@ type PluginRunner struct { // plugin. func (r *PluginRunner) Run(wrapper Wrapper, pluginMap map[string]plugin.Plugin, hs plugin.HandshakeConfig, env []string) (*plugin.Client, error) { // Get a CA TLS Certificate - CACertBytes, CACert, CAKey, err := GenerateCACert() + certBytes, key, err := GenerateCert() if err != nil { return nil, err } // Use CA to sign a client cert and return a configured TLS config - clientTLSConfig, err := CreateClientTLSConfig(CACert, CAKey) + clientTLSConfig, err := CreateClientTLSConfig(certBytes, key) if err != nil { return nil, err } // Use CA to sign a server cert and wrap the values in a response wrapped // token. - wrapToken, err := WrapServerConfig(wrapper, CACertBytes, CACert, CAKey) + wrapToken, err := WrapServerConfig(wrapper, certBytes, key) if err != nil { return nil, err } diff --git a/helper/pluginutil/tls.go b/helper/pluginutil/tls.go index c7aa42ee608e..d4c0946e4fbf 100644 --- a/helper/pluginutil/tls.go +++ b/helper/pluginutil/tls.go @@ -29,58 +29,19 @@ var ( PluginUnwrapTokenEnv = "VAULT_UNWRAP_TOKEN" ) -// GenerateCACert returns a CA cert used to later sign the certificates for the -// plugin client and server. -func GenerateCACert() ([]byte, *x509.Certificate, *ecdsa.PrivateKey, error) { +// generateSignedCert is used internally to create certificates for the plugin +// client and server. These certs are signed by the given CA Cert and Key. +func GenerateCert() ([]byte, *ecdsa.PrivateKey, error) { key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) if err != nil { - return nil, nil, nil, err + return nil, nil, err } host, err := uuid.GenerateUUID() if err != nil { - return nil, nil, nil, err - } - host = "localhost" - template := &x509.Certificate{ - Subject: pkix.Name{ - CommonName: host, - }, - DNSNames: []string{host}, - ExtKeyUsage: []x509.ExtKeyUsage{ - x509.ExtKeyUsageServerAuth, - x509.ExtKeyUsageClientAuth, - }, - KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement | x509.KeyUsageCertSign, - SerialNumber: big.NewInt(mathrand.Int63()), - NotBefore: time.Now().Add(-30 * time.Second), - // 30 years of single-active uptime ought to be enough for anybody - NotAfter: time.Now().Add(262980 * time.Hour), - BasicConstraintsValid: true, - IsCA: true, - } - - certBytes, err := x509.CreateCertificate(rand.Reader, template, template, key.Public(), key) - if err != nil { - return nil, nil, nil, fmt.Errorf("unable to generate replicated cluster certificate: %v", err) + return nil, nil, err } - caCert, err := x509.ParseCertificate(certBytes) - if err != nil { - return nil, nil, nil, fmt.Errorf("error parsing generated replication certificate: %v", err) - } - - return certBytes, caCert, key, nil -} - -// generateSignedCert is used internally to create certificates for the plugin -// client and server. These certs are signed by the given CA Cert and Key. -func generateSignedCert(CACert *x509.Certificate, CAKey *ecdsa.PrivateKey) ([]byte, *x509.Certificate, *ecdsa.PrivateKey, error) { - host, err := uuid.GenerateUUID() - if err != nil { - return nil, nil, nil, err - } - host = "localhost" template := &x509.Certificate{ Subject: pkix.Name{ CommonName: host, @@ -94,48 +55,38 @@ func generateSignedCert(CACert *x509.Certificate, CAKey *ecdsa.PrivateKey) ([]by SerialNumber: big.NewInt(mathrand.Int63()), NotBefore: time.Now().Add(-30 * time.Second), NotAfter: time.Now().Add(262980 * time.Hour), + IsCA: true, } - clientKey, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) - if err != nil { - return nil, nil, nil, errwrap.Wrapf("error generating client key: {{err}}", err) - } - - certBytes, err := x509.CreateCertificate(rand.Reader, template, CACert, clientKey.Public(), CAKey) - if err != nil { - return nil, nil, nil, errwrap.Wrapf("unable to generate client certificate: {{err}}", err) - } - - clientCert, err := x509.ParseCertificate(certBytes) + certBytes, err := x509.CreateCertificate(rand.Reader, template, template, key.Public(), key) if err != nil { - return nil, nil, nil, fmt.Errorf("error parsing generated replication certificate: %v", err) + return nil, nil, errwrap.Wrapf("unable to generate client certificate: {{err}}", err) } - return certBytes, clientCert, clientKey, nil + return certBytes, key, nil } // CreateClientTLSConfig creates a signed certificate and returns a configured // TLS config. -func CreateClientTLSConfig(CACert *x509.Certificate, CAKey *ecdsa.PrivateKey) (*tls.Config, error) { - clientCertBytes, clientCert, clientKey, err := generateSignedCert(CACert, CAKey) +func CreateClientTLSConfig(certBytes []byte, key *ecdsa.PrivateKey) (*tls.Config, error) { + clientCert, err := x509.ParseCertificate(certBytes) if err != nil { - return nil, err + return nil, fmt.Errorf("error parsing generated plugin certificate: %v", err) } cert := tls.Certificate{ - Certificate: [][]byte{clientCertBytes}, - PrivateKey: clientKey, + Certificate: [][]byte{certBytes}, + PrivateKey: key, Leaf: clientCert, } clientCertPool := x509.NewCertPool() - clientCertPool.AddCert(CACert) + clientCertPool.AddCert(clientCert) tlsConfig := &tls.Config{ Certificates: []tls.Certificate{cert}, RootCAs: clientCertPool, - ClientCAs: clientCertPool, - ServerName: CACert.Subject.CommonName, + ServerName: clientCert.Subject.CommonName, MinVersion: tls.VersionTLS12, } @@ -146,19 +97,14 @@ func CreateClientTLSConfig(CACert *x509.Certificate, CAKey *ecdsa.PrivateKey) (* // WrapServerConfig is used to create a server certificate and private key, then // wrap them in an unwrap token for later retrieval by the plugin. -func WrapServerConfig(sys Wrapper, CACertBytes []byte, CACert *x509.Certificate, CAKey *ecdsa.PrivateKey) (string, error) { - serverCertBytes, _, serverKey, err := generateSignedCert(CACert, CAKey) - if err != nil { - return "", err - } - rawKey, err := x509.MarshalECPrivateKey(serverKey) +func WrapServerConfig(sys Wrapper, certBytes []byte, key *ecdsa.PrivateKey) (string, error) { + rawKey, err := x509.MarshalECPrivateKey(key) if err != nil { return "", err } wrapToken, err := sys.ResponseWrapData(map[string]interface{}{ - "CACert": CACertBytes, - "ServerCert": serverCertBytes, + "ServerCert": certBytes, "ServerKey": rawKey, }, time.Second*10, true) @@ -217,22 +163,6 @@ func VaultPluginTLSProvider() (*tls.Config, error) { return nil, errors.New("error during token unwrap request secret is nil") } - // Retrieve and parse the CA Certificate - CABytesRaw, ok := secret.Data["CACert"].(string) - if !ok { - return nil, errors.New("error unmarshalling CA certificate") - } - - CABytes, err := base64.StdEncoding.DecodeString(CABytesRaw) - if err != nil { - return nil, fmt.Errorf("error parsing certificate: %v", err) - } - - CACert, err := x509.ParseCertificate(CABytes) - if err != nil { - return nil, fmt.Errorf("error parsing certificate: %v", err) - } - // Retrieve and parse the server's certificate serverCertBytesRaw, ok := secret.Data["ServerCert"].(string) if !ok { @@ -267,7 +197,7 @@ func VaultPluginTLSProvider() (*tls.Config, error) { // Add CA cert to the cert pool caCertPool := x509.NewCertPool() - caCertPool.AddCert(CACert) + caCertPool.AddCert(serverCert) // Build a certificate object out of the server's cert and private key. cert := tls.Certificate{ diff --git a/helper/strutil/strutil.go b/helper/strutil/strutil.go index 7c7f64d3da88..986928e0ed4d 100644 --- a/helper/strutil/strutil.go +++ b/helper/strutil/strutil.go @@ -29,6 +29,19 @@ func StrListSubset(super, sub []string) bool { return true } +// Parses a comma separated list of strings into a slice of strings. +// The return slice will be sorted and will not contain duplicate or +// empty items. +func ParseDedupAndSortStrings(input string, sep string) []string { + input = strings.TrimSpace(input) + parsed := []string{} + if input == "" { + // Don't return nil + return parsed + } + return RemoveDuplicates(strings.Split(input, sep), false) +} + // Parses a comma separated list of strings into a slice of strings. // The return slice will be sorted and will not contain duplicate or // empty items. The values will be converted to lower case. From f1fa617e03fb6700d5370efb995e6021617f68dc Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 20 Apr 2017 18:46:41 -0700 Subject: [PATCH 084/152] Calls to builtin plugins now go directly to the implementation instead of go-plugin --- builtin/logical/database/backend_test.go | 5 +- builtin/logical/database/dbplugin/plugin.go | 27 ++++++-- cli/commands.go | 6 -- command/plugin_exec.go | 66 ------------------- command/server.go | 61 +++++------------ helper/builtinplugins/builtin.go | 12 ++-- helper/pluginutil/runner.go | 11 ++-- plugins/database/mysql/mysql.go | 11 ++-- plugins/database/mysql/mysql_test.go | 9 ++- plugins/database/postgresql/postgresql.go | 12 ++-- .../database/postgresql/postgresql_test.go | 13 ++-- vault/core.go | 12 +--- vault/plugin_catalog.go | 28 ++++---- 13 files changed, 94 insertions(+), 179 deletions(-) delete mode 100644 command/plugin_exec.go diff --git a/builtin/logical/database/backend_test.go b/builtin/logical/database/backend_test.go index 2615577fdb51..2ece767fcd3e 100644 --- a/builtin/logical/database/backend_test.go +++ b/builtin/logical/database/backend_test.go @@ -11,10 +11,10 @@ import ( "testing" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" - "github.com/hashicorp/vault/helper/builtinplugins" "github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/http" "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/plugins/database/postgresql" "github.com/hashicorp/vault/vault" "github.com/lib/pq" "github.com/mitchellh/mapstructure" @@ -91,8 +91,7 @@ func TestBackend_PluginMain(t *testing.T) { return } - f, _ := builtinplugins.BuiltinPlugins.Get("postgresql-database-plugin") - f() + postgresql.Run() } func TestBackend_config_connection(t *testing.T) { diff --git a/builtin/logical/database/dbplugin/plugin.go b/builtin/logical/database/dbplugin/plugin.go index 61de0fe8ce05..9a6691fbabad 100644 --- a/builtin/logical/database/dbplugin/plugin.go +++ b/builtin/logical/database/dbplugin/plugin.go @@ -33,15 +33,32 @@ type Statements struct { // object in a logging and metrics middleware. func PluginFactory(pluginName string, sys pluginutil.LookWrapper, logger log.Logger) (DatabaseType, error) { // Look for plugin in the plugin catalog - pluginMeta, err := sys.LookupPlugin(pluginName) + pluginRunner, err := sys.LookupPlugin(pluginName) if err != nil { return nil, err } - // create a DatabasePluginClient instance - db, err := newPluginClient(sys, pluginMeta) - if err != nil { - return nil, err + var db DatabaseType + if pluginRunner.Builtin { + // Plugin is builtin so we can retrieve an instance of the interface + // from the pluginRunner. Then cast it to a DatabaseType. + dbRaw, err := pluginRunner.BuiltinFactory() + if err != nil { + return nil, fmt.Errorf("error getting plugin type: %s", err) + } + + var ok bool + db, ok = dbRaw.(DatabaseType) + if !ok { + return nil, fmt.Errorf("unsuported database type: %s", pluginName) + } + + } else { + // create a DatabasePluginClient instance + db, err = newPluginClient(sys, pluginRunner) + if err != nil { + return nil, err + } } typeStr, err := db.Type() diff --git a/cli/commands.go b/cli/commands.go index e7545ca906f3..13f7c8b25aad 100644 --- a/cli/commands.go +++ b/cli/commands.go @@ -331,11 +331,5 @@ func Commands(metaPtr *meta.Meta) map[string]cli.CommandFactory { Ui: metaPtr.Ui, }, nil }, - - "plugin-exec": func() (cli.Command, error) { - return &command.PluginExec{ - Meta: *metaPtr, - }, nil - }, } } diff --git a/command/plugin_exec.go b/command/plugin_exec.go deleted file mode 100644 index 575be14b7d91..000000000000 --- a/command/plugin_exec.go +++ /dev/null @@ -1,66 +0,0 @@ -package command - -import ( - "fmt" - "strings" - - "github.com/hashicorp/vault/helper/builtinplugins" - "github.com/hashicorp/vault/meta" -) - -type PluginExec struct { - meta.Meta -} - -func (c *PluginExec) Run(args []string) int { - flags := c.Meta.FlagSet("plugin-exec", meta.FlagSetDefault) - flags.Usage = func() { c.Ui.Error(c.Help()) } - if err := flags.Parse(args); err != nil { - return 1 - } - - args = flags.Args() - if len(args) != 1 { - flags.Usage() - c.Ui.Error(fmt.Sprintf( - "\nplugin-exec expects one argument: the plugin to execute.")) - return 1 - } - - pluginName := args[0] - - runner, ok := builtinplugins.BuiltinPlugins.Get(pluginName) - if !ok { - c.Ui.Error(fmt.Sprintf( - "No plugin with the name %s found", pluginName)) - return 1 - } - - err := runner() - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error running plugin: %s", err)) - return 1 - } - - return 0 -} - -func (c *PluginExec) Synopsis() string { - return "Runs a builtin plugin. Should only be called by vault." -} - -func (c *PluginExec) Help() string { - helpText := ` -Usage: vault plugin-exec type - - Runs a builtin plugin. Should only be called by vault. - - This will execute a plugin for use in a plugable location in vault. If run by - a cli user it will print a message indicating it can not be executed by anyone - other than vault. For supported plugin types see the vault documentation. - -General Options: -` + meta.GeneralOptionsUsage() - return strings.TrimSpace(helpText) -} diff --git a/command/server.go b/command/server.go index 6548aa58ce63..ef9e3e3a0c97 100644 --- a/command/server.go +++ b/command/server.go @@ -1,10 +1,8 @@ package command import ( - "crypto/sha256" "encoding/base64" "fmt" - "io" "net" "net/http" "net/url" @@ -133,33 +131,6 @@ func (c *ServerCommand) Run(args []string) int { dev = true } - // Record the vault binary's location and SHA-256 checksum for use in - // builtin plugins. - ex, err := os.Executable() - if err != nil { - c.Ui.Output(fmt.Sprintf( - "Error looking up vault binary: %s", err)) - return 1 - } - - file, err := os.Open(ex) - if err != nil { - c.Ui.Output(fmt.Sprintf( - "Error loading vault binary: %s", err)) - return 1 - } - defer file.Close() - - hash := sha256.New() - _, err = io.Copy(hash, file) - if err != nil { - c.Ui.Output(fmt.Sprintf( - "Error checksumming vault binary: %s", err)) - return 1 - } - - sha256Value := hash.Sum(nil) - // Validation if !dev { switch { @@ -254,23 +225,21 @@ func (c *ServerCommand) Run(args []string) int { } coreConfig := &vault.CoreConfig{ - Physical: backend, - RedirectAddr: config.Storage.RedirectAddr, - HAPhysical: nil, - Seal: seal, - AuditBackends: c.AuditBackends, - CredentialBackends: c.CredentialBackends, - LogicalBackends: c.LogicalBackends, - Logger: c.logger, - DisableCache: config.DisableCache, - DisableMlock: config.DisableMlock, - MaxLeaseTTL: config.MaxLeaseTTL, - DefaultLeaseTTL: config.DefaultLeaseTTL, - ClusterName: config.ClusterName, - CacheSize: config.CacheSize, - PluginDirectory: config.PluginDirectory, - VaultBinaryLocation: ex, - VaultBinarySHA256: sha256Value, + Physical: backend, + RedirectAddr: config.Storage.RedirectAddr, + HAPhysical: nil, + Seal: seal, + AuditBackends: c.AuditBackends, + CredentialBackends: c.CredentialBackends, + LogicalBackends: c.LogicalBackends, + Logger: c.logger, + DisableCache: config.DisableCache, + DisableMlock: config.DisableMlock, + MaxLeaseTTL: config.MaxLeaseTTL, + DefaultLeaseTTL: config.DefaultLeaseTTL, + ClusterName: config.ClusterName, + CacheSize: config.CacheSize, + PluginDirectory: config.PluginDirectory, } if dev { coreConfig.DevToken = devRootTokenID diff --git a/helper/builtinplugins/builtin.go b/helper/builtinplugins/builtin.go index beedbb15b857..9c51ae47898b 100644 --- a/helper/builtinplugins/builtin.go +++ b/helper/builtinplugins/builtin.go @@ -5,20 +5,22 @@ import ( "github.com/hashicorp/vault/plugins/database/postgresql" ) +type BuiltinFactory func() (interface{}, error) + var BuiltinPlugins *builtinPlugins = &builtinPlugins{ - plugins: map[string]func() error{ - "mysql-database-plugin": mysql.Run, - "postgresql-database-plugin": postgresql.Run, + plugins: map[string]BuiltinFactory{ + "mysql-database-plugin": mysql.New, + "postgresql-database-plugin": postgresql.New, }, } // The list of builtin plugins should not be changed by any other package, so we // store them in an unexported variable in this unexported struct. type builtinPlugins struct { - plugins map[string]func() error + plugins map[string]BuiltinFactory } -func (b *builtinPlugins) Get(name string) (func() error, bool) { +func (b *builtinPlugins) Get(name string) (BuiltinFactory, bool) { f, ok := b.plugins[name] return f, ok } diff --git a/helper/pluginutil/runner.go b/helper/pluginutil/runner.go index bbc5ab99b752..95de96a5a8c6 100644 --- a/helper/pluginutil/runner.go +++ b/helper/pluginutil/runner.go @@ -40,11 +40,12 @@ type LookWrapper interface { // PluginRunner defines the metadata needed to run a plugin securely with // go-plugin. type PluginRunner struct { - Name string `json:"name"` - Command string `json:"command"` - Args []string `json:"args"` - Sha256 []byte `json:"sha256"` - Builtin bool `json:"builtin"` + Name string `json:"name"` + Command string `json:"command"` + Args []string `json:"args"` + Sha256 []byte `json:"sha256"` + Builtin bool `json:"builtin"` + BuiltinFactory func() (interface{}, error) `json:"-"` } // Run takes a wrapper instance, and the go-plugin paramaters and executes a diff --git a/plugins/database/mysql/mysql.go b/plugins/database/mysql/mysql.go index ea14a6782b84..e7e2a8aea9d3 100644 --- a/plugins/database/mysql/mysql.go +++ b/plugins/database/mysql/mysql.go @@ -23,7 +23,7 @@ type MySQL struct { credsutil.CredentialsProducer } -func New() *MySQL { +func New() (interface{}, error) { connProducer := &connutil.SQLConnectionProducer{} connProducer.Type = mySQLTypeName @@ -37,14 +37,17 @@ func New() *MySQL { CredentialsProducer: credsProducer, } - return dbType + return dbType, nil } // Run instantiates a MySQL object, and runs the RPC server for the plugin func Run() error { - dbType := New() + dbType, err := New() + if err != nil { + return err + } - dbplugin.NewPluginServer(dbType) + dbplugin.NewPluginServer(dbType.(*MySQL)) return nil } diff --git a/plugins/database/mysql/mysql_test.go b/plugins/database/mysql/mysql_test.go index 2b1f27291861..c86f9c2f6b1d 100644 --- a/plugins/database/mysql/mysql_test.go +++ b/plugins/database/mysql/mysql_test.go @@ -66,7 +66,8 @@ func TestMySQL_Initialize(t *testing.T) { "connection_url": connURL, } - db := New() + dbRaw, _ := New() + db := dbRaw.(*MySQL) connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer) err := db.Initialize(connectionDetails, true) @@ -92,7 +93,8 @@ func TestMySQL_CreateUser(t *testing.T) { "connection_url": connURL, } - db := New() + dbRaw, _ := New() + db := dbRaw.(*MySQL) err := db.Initialize(connectionDetails, true) if err != nil { @@ -127,7 +129,8 @@ func TestMySQL_RevokeUser(t *testing.T) { "connection_url": connURL, } - db := New() + dbRaw, _ := New() + db := dbRaw.(*MySQL) err := db.Initialize(connectionDetails, true) if err != nil { diff --git a/plugins/database/postgresql/postgresql.go b/plugins/database/postgresql/postgresql.go index b8449f54986c..5781b6c3d166 100644 --- a/plugins/database/postgresql/postgresql.go +++ b/plugins/database/postgresql/postgresql.go @@ -16,7 +16,8 @@ import ( const postgreSQLTypeName string = "postgres" -func New() *PostgreSQL { +// New implements builtinplugins.BuiltinFactory +func New() (interface{}, error) { connProducer := &connutil.SQLConnectionProducer{} connProducer.Type = postgreSQLTypeName @@ -30,14 +31,17 @@ func New() *PostgreSQL { CredentialsProducer: credsProducer, } - return dbType + return dbType, nil } // Run instatiates a PostgreSQL object, and runs the RPC server for the plugin func Run() error { - dbType := New() + dbType, err := New() + if err != nil { + return err + } - dbplugin.NewPluginServer(dbType) + dbplugin.NewPluginServer(dbType.(*PostgreSQL)) return nil } diff --git a/plugins/database/postgresql/postgresql_test.go b/plugins/database/postgresql/postgresql_test.go index c7ccc8ee8f13..79391dc56ec0 100644 --- a/plugins/database/postgresql/postgresql_test.go +++ b/plugins/database/postgresql/postgresql_test.go @@ -66,7 +66,9 @@ func TestPostgreSQL_Initialize(t *testing.T) { "connection_url": connURL, } - db := New() + dbRaw, _ := New() + db := dbRaw.(*PostgreSQL) + connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer) err := db.Initialize(connectionDetails, true) @@ -92,7 +94,8 @@ func TestPostgreSQL_CreateUser(t *testing.T) { "connection_url": connURL, } - db := New() + dbRaw, _ := New() + db := dbRaw.(*PostgreSQL) err := db.Initialize(connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) @@ -136,7 +139,8 @@ func TestPostgreSQL_RenewUser(t *testing.T) { "connection_url": connURL, } - db := New() + dbRaw, _ := New() + db := dbRaw.(*PostgreSQL) err := db.Initialize(connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) @@ -176,7 +180,8 @@ func TestPostgreSQL_RevokeUser(t *testing.T) { "connection_url": connURL, } - db := New() + dbRaw, _ := New() + db := dbRaw.(*PostgreSQL) err := db.Initialize(connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) diff --git a/vault/core.go b/vault/core.go index ef99741bfd5f..01ab49f752e3 100644 --- a/vault/core.go +++ b/vault/core.go @@ -335,12 +335,6 @@ type Core struct { // pluginDirectory is the location vault will look for plugin binaries pluginDirectory string - // vaultBinaryLocation is used to run builtin plugins in secure mode - vaultBinaryLocation string - - // vaultBinarySHA256 is used to run builtin plugins in secure mode - vaultBinarySHA256 []byte - // pluginCatalog is used to manage plugin configurations pluginCatalog *PluginCatalog @@ -389,9 +383,7 @@ type CoreConfig struct { EnableUI bool `json:"ui" structs:"ui" mapstructure:"ui"` - PluginDirectory string `json:"plugin_directory" structs:"plugin_directory" mapstructure:"plugin_directory"` - VaultBinaryLocation string `json:"vault_binary_location" structs:"vault_binary_location" mapstructure:"vault_binary_location"` - VaultBinarySHA256 []byte `json:"vault_binary_sha256" structs:"vault_binary_sha256" mapstructure:"vault_binary_sha256"` + PluginDirectory string `json:"plugin_directory" structs:"plugin_directory" mapstructure:"plugin_directory"` ReloadFuncs *map[string][]ReloadFunc ReloadFuncsLock *sync.RWMutex @@ -449,8 +441,6 @@ func NewCore(conf *CoreConfig) (*Core, error) { clusterName: conf.ClusterName, clusterListenerShutdownCh: make(chan struct{}), clusterListenerShutdownSuccessCh: make(chan struct{}), - vaultBinaryLocation: conf.VaultBinaryLocation, - vaultBinarySHA256: conf.VaultBinarySHA256, disableMlock: conf.DisableMlock, } diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go index b89224780c40..598a16fac553 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugin_catalog.go @@ -23,20 +23,16 @@ var ( // to be registered to the catalog before they can be used in backends. Builtin // plugins are automatically detected and included in the catalog. type PluginCatalog struct { - catalogView *BarrierView - directory string - vaultCommand string - vaultSHA256 []byte + catalogView *BarrierView + directory string lock sync.RWMutex } func (c *Core) setupPluginCatalog() error { c.pluginCatalog = &PluginCatalog{ - catalogView: c.systemBarrierView.SubView(pluginCatalogPrefix), - directory: c.pluginDirectory, - vaultCommand: c.vaultBinaryLocation, - vaultSHA256: c.vaultBinarySHA256, + catalogView: c.systemBarrierView.SubView(pluginCatalogPrefix), + directory: c.pluginDirectory, } return nil @@ -64,17 +60,15 @@ func (c *PluginCatalog) Get(name string) (*pluginutil.PluginRunner, error) { } // Look for builtin plugins - if _, ok := builtinplugins.BuiltinPlugins.Get(name); !ok { - return nil, fmt.Errorf("no plugin found with name: %s", name) + if factory, ok := builtinplugins.BuiltinPlugins.Get(name); ok { + return &pluginutil.PluginRunner{ + Name: name, + Builtin: true, + BuiltinFactory: factory, + }, nil } - return &pluginutil.PluginRunner{ - Name: name, - Command: c.vaultCommand, - Args: []string{"plugin-exec", name}, - Sha256: c.vaultSHA256, - Builtin: true, - }, nil + return nil, fmt.Errorf("no plugin found with name: %s", name) } // Set registers a new external plugin with the catalog, or updates an existing From 9abc31ece7ea25605af8d6538d8561d15e3a5268 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Fri, 21 Apr 2017 09:10:26 -0700 Subject: [PATCH 085/152] Fix tests --- plugins/database/mysql/mysql.go | 1 + vault/logical_system_test.go | 2 -- vault/plugin_catalog_test.go | 4 ---- 3 files changed, 1 insertion(+), 6 deletions(-) diff --git a/plugins/database/mysql/mysql.go b/plugins/database/mysql/mysql.go index e7e2a8aea9d3..6485aaa8625d 100644 --- a/plugins/database/mysql/mysql.go +++ b/plugins/database/mysql/mysql.go @@ -23,6 +23,7 @@ type MySQL struct { credsutil.CredentialsProducer } +// New implements builtinplugins.BuiltinFactory func New() (interface{}, error) { connProducer := &connutil.SQLConnectionProducer{} connProducer.Type = mySQLTypeName diff --git a/vault/logical_system_test.go b/vault/logical_system_test.go index 0785e07a1e08..9da4cbdec23a 100644 --- a/vault/logical_system_test.go +++ b/vault/logical_system_test.go @@ -1122,8 +1122,6 @@ func TestSystemBackend_PluginCatalog_CRUD(t *testing.T) { t.Fatalf("error: %v", err) } c.pluginCatalog.directory = sym - c.pluginCatalog.vaultCommand = "vault" - c.pluginCatalog.vaultSHA256 = []byte{'1'} req := logical.TestRequest(t, logical.ListOperation, "plugin-catalog/") resp, err := b.HandleRequest(req) diff --git a/vault/plugin_catalog_test.go b/vault/plugin_catalog_test.go index e78e7d963013..c33a890cdf75 100644 --- a/vault/plugin_catalog_test.go +++ b/vault/plugin_catalog_test.go @@ -21,8 +21,6 @@ func TestPluginCatalog_CRUD(t *testing.T) { t.Fatalf("error: %v", err) } core.pluginCatalog.directory = sym - core.pluginCatalog.vaultCommand = "vault" - core.pluginCatalog.vaultSHA256 = []byte{'1'} // Get builtin plugin p, err := core.pluginCatalog.Get("mysql-database-plugin") @@ -99,8 +97,6 @@ func TestPluginCatalog_List(t *testing.T) { t.Fatalf("error: %v", err) } core.pluginCatalog.directory = sym - core.pluginCatalog.vaultCommand = "vault" - core.pluginCatalog.vaultSHA256 = []byte{'1'} // Get builtin plugins and sort them builtinKeys := builtinplugins.BuiltinPlugins.Keys() From 3ceb7b69e1e28f831efe097b9f72c33f473007e5 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Fri, 21 Apr 2017 10:24:34 -0700 Subject: [PATCH 086/152] Fix tests --- vault/logical_system_test.go | 13 +++++++++---- vault/plugin_catalog_test.go | 20 +++++++++++++++++--- 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/vault/logical_system_test.go b/vault/logical_system_test.go index 9da4cbdec23a..e9836946c3bf 100644 --- a/vault/logical_system_test.go +++ b/vault/logical_system_test.go @@ -1141,13 +1141,18 @@ func TestSystemBackend_PluginCatalog_CRUD(t *testing.T) { expectedBuiltin := &pluginutil.PluginRunner{ Name: "mysql-database-plugin", - Command: "vault", - Args: []string{"plugin-exec", "mysql-database-plugin"}, - Sha256: []byte{'1'}, Builtin: true, } + expectedBuiltin.BuiltinFactory, _ = builtinplugins.BuiltinPlugins.Get("mysql-database-plugin") + + p := resp.Data["plugin"].(*pluginutil.PluginRunner) + if &(p.BuiltinFactory) == &(expectedBuiltin.BuiltinFactory) { + t.Fatal("expected BuiltinFactory did not match actual") + } - if !reflect.DeepEqual(resp.Data["plugin"].(*pluginutil.PluginRunner), expectedBuiltin) { + expectedBuiltin.BuiltinFactory = nil + p.BuiltinFactory = nil + if !reflect.DeepEqual(p, expectedBuiltin) { t.Fatalf("expected did not match actual, got %#v\n expected %#v\n", resp.Data["plugin"].(*pluginutil.PluginRunner), expectedBuiltin) } diff --git a/vault/plugin_catalog_test.go b/vault/plugin_catalog_test.go index c33a890cdf75..57e864892bb0 100644 --- a/vault/plugin_catalog_test.go +++ b/vault/plugin_catalog_test.go @@ -30,12 +30,15 @@ func TestPluginCatalog_CRUD(t *testing.T) { expectedBuiltin := &pluginutil.PluginRunner{ Name: "mysql-database-plugin", - Command: "vault", - Args: []string{"plugin-exec", "mysql-database-plugin"}, - Sha256: []byte{'1'}, Builtin: true, } + expectedBuiltin.BuiltinFactory, _ = builtinplugins.BuiltinPlugins.Get("mysql-database-plugin") + if &(p.BuiltinFactory) == &(expectedBuiltin.BuiltinFactory) { + t.Fatal("expected BuiltinFactory did not match actual") + } + expectedBuiltin.BuiltinFactory = nil + p.BuiltinFactory = nil if !reflect.DeepEqual(p, expectedBuiltin) { t.Fatalf("expected did not match actual, got %#v\n expected %#v\n", p, expectedBuiltin) } @@ -83,6 +86,17 @@ func TestPluginCatalog_CRUD(t *testing.T) { t.Fatalf("unexpected error %v", err) } + expectedBuiltin = &pluginutil.PluginRunner{ + Name: "mysql-database-plugin", + Builtin: true, + } + expectedBuiltin.BuiltinFactory, _ = builtinplugins.BuiltinPlugins.Get("mysql-database-plugin") + + if &(p.BuiltinFactory) == &(expectedBuiltin.BuiltinFactory) { + t.Fatal("expected BuiltinFactory did not match actual") + } + expectedBuiltin.BuiltinFactory = nil + p.BuiltinFactory = nil if !reflect.DeepEqual(p, expectedBuiltin) { t.Fatalf("expected did not match actual, got %#v\n expected %#v\n", p, expectedBuiltin) } From c5d5abef1163cf0893fe6f6a6b05a680c0699b28 Mon Sep 17 00:00:00 2001 From: Calvin Leung Huang Date: Sun, 23 Apr 2017 09:02:57 +0800 Subject: [PATCH 087/152] Add cassandra plugin --- .../cassandra-database-plugin/main.go | 16 + plugins/database/cassandra/cassandra.go | 145 +++ plugins/database/cassandra/cassandra_test.go | 226 ++++ .../cassandra/test-fixtures/cassandra.yaml | 1146 +++++++++++++++++ plugins/helper/database/connutil/cassandra.go | 7 +- 5 files changed, 1537 insertions(+), 3 deletions(-) create mode 100644 plugins/database/cassandra/cassandra-database-plugin/main.go create mode 100644 plugins/database/cassandra/cassandra.go create mode 100644 plugins/database/cassandra/cassandra_test.go create mode 100644 plugins/database/cassandra/test-fixtures/cassandra.yaml diff --git a/plugins/database/cassandra/cassandra-database-plugin/main.go b/plugins/database/cassandra/cassandra-database-plugin/main.go new file mode 100644 index 000000000000..79f0e0dbe94e --- /dev/null +++ b/plugins/database/cassandra/cassandra-database-plugin/main.go @@ -0,0 +1,16 @@ +package main + +import ( + "fmt" + "os" + + "github.com/hashicorp/vault/plugins/database/cassandra" +) + +func main() { + err := cassandra.Run() + if err != nil { + fmt.Println(err) + os.Exit(1) + } +} diff --git a/plugins/database/cassandra/cassandra.go b/plugins/database/cassandra/cassandra.go new file mode 100644 index 000000000000..621d6e375282 --- /dev/null +++ b/plugins/database/cassandra/cassandra.go @@ -0,0 +1,145 @@ +package cassandra + +import ( + "fmt" + "strings" + "time" + + "github.com/gocql/gocql" + "github.com/hashicorp/vault/builtin/logical/database/dbplugin" + "github.com/hashicorp/vault/helper/strutil" + "github.com/hashicorp/vault/plugins/helper/database/connutil" + "github.com/hashicorp/vault/plugins/helper/database/credsutil" + "github.com/hashicorp/vault/plugins/helper/database/dbutil" +) + +const ( + defaultCreationCQL = `CREATE USER '{{username}}' WITH PASSWORD '{{password}}' NOSUPERUSER;` + defaultRollbackCQL = `DROP USER '{{username}}';` + cassandraTypeName = "cassandra" +) + +type Cassandra struct { + connutil.ConnectionProducer + credsutil.CredentialsProducer +} + +func New() *Cassandra { + connProducer := &connutil.CassandraConnectionProducer{} + connProducer.Type = cassandraTypeName + + credsProducer := &credsutil.CassandraCredentialsProducer{} + + dbType := &Cassandra{ + ConnectionProducer: connProducer, + CredentialsProducer: credsProducer, + } + + return dbType +} + +// Run instantiates a MySQL object, and runs the RPC server for the plugin +func Run() error { + dbType := New() + + dbplugin.NewPluginServer(dbType) + + return nil +} + +func (c *Cassandra) Type() (string, error) { + return cassandraTypeName, nil +} + +func (c *Cassandra) getConnection() (*gocql.Session, error) { + session, err := c.Connection() + if err != nil { + return nil, err + } + + return session.(*gocql.Session), nil +} + +// func (c *Cassandra) CreateUser(statements dbplugin.Statements, username, password, expiration string) error { +func (c *Cassandra) CreateUser(statements dbplugin.Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) { + // Grab the lock + c.Lock() + defer c.Unlock() + + // Get the connection + session, err := c.getConnection() + if err != nil { + return "", "", err + } + + creationCQL := statements.CreationStatements + if creationCQL == "" { + creationCQL = defaultCreationCQL + } + rollbackCQL := statements.RollbackStatements + if rollbackCQL == "" { + rollbackCQL = defaultRollbackCQL + } + + username, err = c.GenerateUsername(usernamePrefix) + if err != nil { + return "", "", err + } + + password, err = c.GeneratePassword() + if err != nil { + return "", "", err + } + + // Execute each query + for _, query := range strutil.ParseArbitraryStringSlice(creationCQL, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + err = session.Query(dbutil.QueryHelper(query, map[string]string{ + "username": username, + "password": password, + })).Exec() + if err != nil { + for _, query := range strutil.ParseArbitraryStringSlice(rollbackCQL, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + session.Query(dbutil.QueryHelper(query, map[string]string{ + "username": username, + "password": password, + })).Exec() + } + return "", "", err + } + } + + return username, password, nil +} + +func (c *Cassandra) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { + // NOOP + return nil +} + +func (c *Cassandra) RevokeUser(statements dbplugin.Statements, username string) error { + // Grab the lock + c.Lock() + defer c.Unlock() + + session, err := c.getConnection() + if err != nil { + return err + } + + err = session.Query(fmt.Sprintf("DROP USER '%s'", username)).Exec() + if err != nil { + return fmt.Errorf("error removing user '%s': %s", username, err) + } + + return nil +} diff --git a/plugins/database/cassandra/cassandra_test.go b/plugins/database/cassandra/cassandra_test.go new file mode 100644 index 000000000000..b81c32710c45 --- /dev/null +++ b/plugins/database/cassandra/cassandra_test.go @@ -0,0 +1,226 @@ +package cassandra + +import ( + "os" + "strconv" + "testing" + "time" + + "fmt" + + "github.com/gocql/gocql" + "github.com/hashicorp/vault/builtin/logical/database/dbplugin" + "github.com/hashicorp/vault/plugins/helper/database/connutil" + dockertest "gopkg.in/ory-am/dockertest.v3" +) + +func prepareCassandraTestContainer(t *testing.T) (cleanup func(), retURL string) { + if os.Getenv("CASSANDRA_HOST") != "" { + return func() {}, os.Getenv("CASSANDRA_HOST") + } + + pool, err := dockertest.NewPool("") + if err != nil { + t.Fatalf("Failed to connect to docker: %s", err) + } + + cwd, _ := os.Getwd() + cassandraMountPath := fmt.Sprintf("%s/test-fixtures/:/etc/cassandra/", cwd) + + ro := &dockertest.RunOptions{ + Repository: "cassandra", + Tag: "latest", + Mounts: []string{cassandraMountPath}, + } + resource, err := pool.RunWithOptions(ro) + if err != nil { + t.Fatalf("Could not start local cassandra docker container: %s", err) + } + + cleanup = func() { + err := pool.Purge(resource) + if err != nil { + t.Fatalf("Failed to cleanup local container: %s", err) + } + } + + retURL = fmt.Sprintf("localhost:%s", resource.GetPort("9042/tcp")) + port, _ := strconv.Atoi(resource.GetPort("9042/tcp")) + + // exponential backoff-retry + if err = pool.Retry(func() error { + clusterConfig := gocql.NewCluster(retURL) + clusterConfig.Authenticator = gocql.PasswordAuthenticator{ + Username: "cassandra", + Password: "cassandra", + } + clusterConfig.ProtoVersion = 4 + clusterConfig.Port = port + + session, err := clusterConfig.CreateSession() + if err != nil { + return fmt.Errorf("error creating session: %s", err) + } + defer session.Close() + return nil + }); err != nil { + t.Fatalf("Could not connect to cassandra docker container: %s", err) + } + return +} + +func TestCassandra_Initialize(t *testing.T) { + cleanup, connURL := prepareCassandraTestContainer(t) + defer cleanup() + + connectionDetails := map[string]interface{}{ + "hosts": connURL, + "username": "cassandra", + "password": "cassandra", + "protocol_version": 4, + } + + db := New() + connProducer := db.ConnectionProducer.(*connutil.CassandraConnectionProducer) + + err := db.Initialize(connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + if !connProducer.Initialized { + t.Fatal("Database should be initalized") + } + + err = db.Close() + if err != nil { + t.Fatalf("err: %s", err) + } +} + +func TestCassandra_CreateUser(t *testing.T) { + cleanup, connURL := prepareCassandraTestContainer(t) + defer cleanup() + + connectionDetails := map[string]interface{}{ + "hosts": connURL, + "username": "cassandra", + "password": "cassandra", + "protocol_version": 4, + } + + db := New() + err := db.Initialize(connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + statements := dbplugin.Statements{ + CreationStatements: testCassandraRole, + } + + username, password, err := db.CreateUser(statements, "test", time.Now().Add(time.Minute)) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err := testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } +} + +func TestMyCassandra_RenewUser(t *testing.T) { + cleanup, connURL := prepareCassandraTestContainer(t) + defer cleanup() + + connectionDetails := map[string]interface{}{ + "hosts": connURL, + "username": "cassandra", + "password": "cassandra", + "protocol_version": 4, + } + + db := New() + err := db.Initialize(connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + statements := dbplugin.Statements{ + CreationStatements: testCassandraRole, + } + + username, password, err := db.CreateUser(statements, "test", time.Now().Add(time.Minute)) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err := testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } + + err = db.RenewUser(statements, username, time.Now().Add(time.Minute)) + if err != nil { + t.Fatalf("err: %s", err) + } +} + +func TestCassandra_RevokeUser(t *testing.T) { + cleanup, connURL := prepareCassandraTestContainer(t) + defer cleanup() + + connectionDetails := map[string]interface{}{ + "hosts": connURL, + "username": "cassandra", + "password": "cassandra", + "protocol_version": 4, + } + + db := New() + err := db.Initialize(connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + statements := dbplugin.Statements{ + CreationStatements: testCassandraRole, + } + + username, password, err := db.CreateUser(statements, "test", time.Now().Add(time.Minute)) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err = testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } + + // Test default revoke statememts + err = db.RevokeUser(statements, username) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err = testCredsExist(t, connURL, username, password); err == nil { + t.Fatal("Credentials were not revoked") + } +} + +func testCredsExist(t testing.TB, connURL, username, password string) error { + clusterConfig := gocql.NewCluster(connURL) + clusterConfig.Authenticator = gocql.PasswordAuthenticator{ + Username: username, + Password: password, + } + clusterConfig.ProtoVersion = 4 + + session, err := clusterConfig.CreateSession() + if err != nil { + return fmt.Errorf("error creating session: %s", err) + } + defer session.Close() + return nil +} + +const testCassandraRole = `CREATE USER '{{username}}' WITH PASSWORD '{{password}}' NOSUPERUSER; +GRANT ALL PERMISSIONS ON ALL KEYSPACES TO {{username}};` diff --git a/plugins/database/cassandra/test-fixtures/cassandra.yaml b/plugins/database/cassandra/test-fixtures/cassandra.yaml new file mode 100644 index 000000000000..5b12c8cf4e69 --- /dev/null +++ b/plugins/database/cassandra/test-fixtures/cassandra.yaml @@ -0,0 +1,1146 @@ +# Cassandra storage config YAML + +# NOTE: +# See http://wiki.apache.org/cassandra/StorageConfiguration for +# full explanations of configuration directives +# /NOTE + +# The name of the cluster. This is mainly used to prevent machines in +# one logical cluster from joining another. +cluster_name: 'Test Cluster' + +# This defines the number of tokens randomly assigned to this node on the ring +# The more tokens, relative to other nodes, the larger the proportion of data +# that this node will store. You probably want all nodes to have the same number +# of tokens assuming they have equal hardware capability. +# +# If you leave this unspecified, Cassandra will use the default of 1 token for legacy compatibility, +# and will use the initial_token as described below. +# +# Specifying initial_token will override this setting on the node's initial start, +# on subsequent starts, this setting will apply even if initial token is set. +# +# If you already have a cluster with 1 token per node, and wish to migrate to +# multiple tokens per node, see http://wiki.apache.org/cassandra/Operations +num_tokens: 256 + +# Triggers automatic allocation of num_tokens tokens for this node. The allocation +# algorithm attempts to choose tokens in a way that optimizes replicated load over +# the nodes in the datacenter for the replication strategy used by the specified +# keyspace. +# +# The load assigned to each node will be close to proportional to its number of +# vnodes. +# +# Only supported with the Murmur3Partitioner. +# allocate_tokens_for_keyspace: KEYSPACE + +# initial_token allows you to specify tokens manually. While you can use it with +# vnodes (num_tokens > 1, above) -- in which case you should provide a +# comma-separated list -- it's primarily used when adding nodes to legacy clusters +# that do not have vnodes enabled. +# initial_token: + +# See http://wiki.apache.org/cassandra/HintedHandoff +# May either be "true" or "false" to enable globally +hinted_handoff_enabled: true + +# When hinted_handoff_enabled is true, a black list of data centers that will not +# perform hinted handoff +# hinted_handoff_disabled_datacenters: +# - DC1 +# - DC2 + +# this defines the maximum amount of time a dead host will have hints +# generated. After it has been dead this long, new hints for it will not be +# created until it has been seen alive and gone down again. +max_hint_window_in_ms: 10800000 # 3 hours + +# Maximum throttle in KBs per second, per delivery thread. This will be +# reduced proportionally to the number of nodes in the cluster. (If there +# are two nodes in the cluster, each delivery thread will use the maximum +# rate; if there are three, each will throttle to half of the maximum, +# since we expect two nodes to be delivering hints simultaneously.) +hinted_handoff_throttle_in_kb: 1024 + +# Number of threads with which to deliver hints; +# Consider increasing this number when you have multi-dc deployments, since +# cross-dc handoff tends to be slower +max_hints_delivery_threads: 2 + +# Directory where Cassandra should store hints. +# If not set, the default directory is $CASSANDRA_HOME/data/hints. +# hints_directory: /var/lib/cassandra/hints + +# How often hints should be flushed from the internal buffers to disk. +# Will *not* trigger fsync. +hints_flush_period_in_ms: 10000 + +# Maximum size for a single hints file, in megabytes. +max_hints_file_size_in_mb: 128 + +# Compression to apply to the hint files. If omitted, hints files +# will be written uncompressed. LZ4, Snappy, and Deflate compressors +# are supported. +#hints_compression: +# - class_name: LZ4Compressor +# parameters: +# - + +# Maximum throttle in KBs per second, total. This will be +# reduced proportionally to the number of nodes in the cluster. +batchlog_replay_throttle_in_kb: 1024 + +# Authentication backend, implementing IAuthenticator; used to identify users +# Out of the box, Cassandra provides org.apache.cassandra.auth.{AllowAllAuthenticator, +# PasswordAuthenticator}. +# +# - AllowAllAuthenticator performs no checks - set it to disable authentication. +# - PasswordAuthenticator relies on username/password pairs to authenticate +# users. It keeps usernames and hashed passwords in system_auth.credentials table. +# Please increase system_auth keyspace replication factor if you use this authenticator. +# If using PasswordAuthenticator, CassandraRoleManager must also be used (see below) +authenticator: PasswordAuthenticator + +# Authorization backend, implementing IAuthorizer; used to limit access/provide permissions +# Out of the box, Cassandra provides org.apache.cassandra.auth.{AllowAllAuthorizer, +# CassandraAuthorizer}. +# +# - AllowAllAuthorizer allows any action to any user - set it to disable authorization. +# - CassandraAuthorizer stores permissions in system_auth.permissions table. Please +# increase system_auth keyspace replication factor if you use this authorizer. +authorizer: CassandraAuthorizer + +# Part of the Authentication & Authorization backend, implementing IRoleManager; used +# to maintain grants and memberships between roles. +# Out of the box, Cassandra provides org.apache.cassandra.auth.CassandraRoleManager, +# which stores role information in the system_auth keyspace. Most functions of the +# IRoleManager require an authenticated login, so unless the configured IAuthenticator +# actually implements authentication, most of this functionality will be unavailable. +# +# - CassandraRoleManager stores role data in the system_auth keyspace. Please +# increase system_auth keyspace replication factor if you use this role manager. +role_manager: CassandraRoleManager + +# Validity period for roles cache (fetching granted roles can be an expensive +# operation depending on the role manager, CassandraRoleManager is one example) +# Granted roles are cached for authenticated sessions in AuthenticatedUser and +# after the period specified here, become eligible for (async) reload. +# Defaults to 2000, set to 0 to disable caching entirely. +# Will be disabled automatically for AllowAllAuthenticator. +roles_validity_in_ms: 2000 + +# Refresh interval for roles cache (if enabled). +# After this interval, cache entries become eligible for refresh. Upon next +# access, an async reload is scheduled and the old value returned until it +# completes. If roles_validity_in_ms is non-zero, then this must be +# also. +# Defaults to the same value as roles_validity_in_ms. +# roles_update_interval_in_ms: 2000 + +# Validity period for permissions cache (fetching permissions can be an +# expensive operation depending on the authorizer, CassandraAuthorizer is +# one example). Defaults to 2000, set to 0 to disable. +# Will be disabled automatically for AllowAllAuthorizer. +permissions_validity_in_ms: 2000 + +# Refresh interval for permissions cache (if enabled). +# After this interval, cache entries become eligible for refresh. Upon next +# access, an async reload is scheduled and the old value returned until it +# completes. If permissions_validity_in_ms is non-zero, then this must be +# also. +# Defaults to the same value as permissions_validity_in_ms. +# permissions_update_interval_in_ms: 2000 + +# Validity period for credentials cache. This cache is tightly coupled to +# the provided PasswordAuthenticator implementation of IAuthenticator. If +# another IAuthenticator implementation is configured, this cache will not +# be automatically used and so the following settings will have no effect. +# Please note, credentials are cached in their encrypted form, so while +# activating this cache may reduce the number of queries made to the +# underlying table, it may not bring a significant reduction in the +# latency of individual authentication attempts. +# Defaults to 2000, set to 0 to disable credentials caching. +credentials_validity_in_ms: 2000 + +# Refresh interval for credentials cache (if enabled). +# After this interval, cache entries become eligible for refresh. Upon next +# access, an async reload is scheduled and the old value returned until it +# completes. If credentials_validity_in_ms is non-zero, then this must be +# also. +# Defaults to the same value as credentials_validity_in_ms. +# credentials_update_interval_in_ms: 2000 + +# The partitioner is responsible for distributing groups of rows (by +# partition key) across nodes in the cluster. You should leave this +# alone for new clusters. The partitioner can NOT be changed without +# reloading all data, so when upgrading you should set this to the +# same partitioner you were already using. +# +# Besides Murmur3Partitioner, partitioners included for backwards +# compatibility include RandomPartitioner, ByteOrderedPartitioner, and +# OrderPreservingPartitioner. +# +partitioner: org.apache.cassandra.dht.Murmur3Partitioner + +# Directories where Cassandra should store data on disk. Cassandra +# will spread data evenly across them, subject to the granularity of +# the configured compaction strategy. +# If not set, the default directory is $CASSANDRA_HOME/data/data. +data_file_directories: + - /var/lib/cassandra/data + +# commit log. when running on magnetic HDD, this should be a +# separate spindle than the data directories. +# If not set, the default directory is $CASSANDRA_HOME/data/commitlog. +commitlog_directory: /var/lib/cassandra/commitlog + +# Enable / disable CDC functionality on a per-node basis. This modifies the logic used +# for write path allocation rejection (standard: never reject. cdc: reject Mutation +# containing a CDC-enabled table if at space limit in cdc_raw_directory). +cdc_enabled: false + +# CommitLogSegments are moved to this directory on flush if cdc_enabled: true and the +# segment contains mutations for a CDC-enabled table. This should be placed on a +# separate spindle than the data directories. If not set, the default directory is +# $CASSANDRA_HOME/data/cdc_raw. +# cdc_raw_directory: /var/lib/cassandra/cdc_raw + +# Policy for data disk failures: +# +# die +# shut down gossip and client transports and kill the JVM for any fs errors or +# single-sstable errors, so the node can be replaced. +# +# stop_paranoid +# shut down gossip and client transports even for single-sstable errors, +# kill the JVM for errors during startup. +# +# stop +# shut down gossip and client transports, leaving the node effectively dead, but +# can still be inspected via JMX, kill the JVM for errors during startup. +# +# best_effort +# stop using the failed disk and respond to requests based on +# remaining available sstables. This means you WILL see obsolete +# data at CL.ONE! +# +# ignore +# ignore fatal errors and let requests fail, as in pre-1.2 Cassandra +disk_failure_policy: stop + +# Policy for commit disk failures: +# +# die +# shut down gossip and Thrift and kill the JVM, so the node can be replaced. +# +# stop +# shut down gossip and Thrift, leaving the node effectively dead, but +# can still be inspected via JMX. +# +# stop_commit +# shutdown the commit log, letting writes collect but +# continuing to service reads, as in pre-2.0.5 Cassandra +# +# ignore +# ignore fatal errors and let the batches fail +commit_failure_policy: stop + +# Maximum size of the native protocol prepared statement cache +# +# Valid values are either "auto" (omitting the value) or a value greater 0. +# +# Note that specifying a too large value will result in long running GCs and possbily +# out-of-memory errors. Keep the value at a small fraction of the heap. +# +# If you constantly see "prepared statements discarded in the last minute because +# cache limit reached" messages, the first step is to investigate the root cause +# of these messages and check whether prepared statements are used correctly - +# i.e. use bind markers for variable parts. +# +# Do only change the default value, if you really have more prepared statements than +# fit in the cache. In most cases it is not neccessary to change this value. +# Constantly re-preparing statements is a performance penalty. +# +# Default value ("auto") is 1/256th of the heap or 10MB, whichever is greater +prepared_statements_cache_size_mb: + +# Maximum size of the Thrift prepared statement cache +# +# If you do not use Thrift at all, it is safe to leave this value at "auto". +# +# See description of 'prepared_statements_cache_size_mb' above for more information. +# +# Default value ("auto") is 1/256th of the heap or 10MB, whichever is greater +thrift_prepared_statements_cache_size_mb: + +# Maximum size of the key cache in memory. +# +# Each key cache hit saves 1 seek and each row cache hit saves 2 seeks at the +# minimum, sometimes more. The key cache is fairly tiny for the amount of +# time it saves, so it's worthwhile to use it at large numbers. +# The row cache saves even more time, but must contain the entire row, +# so it is extremely space-intensive. It's best to only use the +# row cache if you have hot rows or static rows. +# +# NOTE: if you reduce the size, you may not get you hottest keys loaded on startup. +# +# Default value is empty to make it "auto" (min(5% of Heap (in MB), 100MB)). Set to 0 to disable key cache. +key_cache_size_in_mb: + +# Duration in seconds after which Cassandra should +# save the key cache. Caches are saved to saved_caches_directory as +# specified in this configuration file. +# +# Saved caches greatly improve cold-start speeds, and is relatively cheap in +# terms of I/O for the key cache. Row cache saving is much more expensive and +# has limited use. +# +# Default is 14400 or 4 hours. +key_cache_save_period: 14400 + +# Number of keys from the key cache to save +# Disabled by default, meaning all keys are going to be saved +# key_cache_keys_to_save: 100 + +# Row cache implementation class name. Available implementations: +# +# org.apache.cassandra.cache.OHCProvider +# Fully off-heap row cache implementation (default). +# +# org.apache.cassandra.cache.SerializingCacheProvider +# This is the row cache implementation availabile +# in previous releases of Cassandra. +# row_cache_class_name: org.apache.cassandra.cache.OHCProvider + +# Maximum size of the row cache in memory. +# Please note that OHC cache implementation requires some additional off-heap memory to manage +# the map structures and some in-flight memory during operations before/after cache entries can be +# accounted against the cache capacity. This overhead is usually small compared to the whole capacity. +# Do not specify more memory that the system can afford in the worst usual situation and leave some +# headroom for OS block level cache. Do never allow your system to swap. +# +# Default value is 0, to disable row caching. +row_cache_size_in_mb: 0 + +# Duration in seconds after which Cassandra should save the row cache. +# Caches are saved to saved_caches_directory as specified in this configuration file. +# +# Saved caches greatly improve cold-start speeds, and is relatively cheap in +# terms of I/O for the key cache. Row cache saving is much more expensive and +# has limited use. +# +# Default is 0 to disable saving the row cache. +row_cache_save_period: 0 + +# Number of keys from the row cache to save. +# Specify 0 (which is the default), meaning all keys are going to be saved +# row_cache_keys_to_save: 100 + +# Maximum size of the counter cache in memory. +# +# Counter cache helps to reduce counter locks' contention for hot counter cells. +# In case of RF = 1 a counter cache hit will cause Cassandra to skip the read before +# write entirely. With RF > 1 a counter cache hit will still help to reduce the duration +# of the lock hold, helping with hot counter cell updates, but will not allow skipping +# the read entirely. Only the local (clock, count) tuple of a counter cell is kept +# in memory, not the whole counter, so it's relatively cheap. +# +# NOTE: if you reduce the size, you may not get you hottest keys loaded on startup. +# +# Default value is empty to make it "auto" (min(2.5% of Heap (in MB), 50MB)). Set to 0 to disable counter cache. +# NOTE: if you perform counter deletes and rely on low gcgs, you should disable the counter cache. +counter_cache_size_in_mb: + +# Duration in seconds after which Cassandra should +# save the counter cache (keys only). Caches are saved to saved_caches_directory as +# specified in this configuration file. +# +# Default is 7200 or 2 hours. +counter_cache_save_period: 7200 + +# Number of keys from the counter cache to save +# Disabled by default, meaning all keys are going to be saved +# counter_cache_keys_to_save: 100 + +# saved caches +# If not set, the default directory is $CASSANDRA_HOME/data/saved_caches. +saved_caches_directory: /var/lib/cassandra/saved_caches + +# commitlog_sync may be either "periodic" or "batch." +# +# When in batch mode, Cassandra won't ack writes until the commit log +# has been fsynced to disk. It will wait +# commitlog_sync_batch_window_in_ms milliseconds between fsyncs. +# This window should be kept short because the writer threads will +# be unable to do extra work while waiting. (You may need to increase +# concurrent_writes for the same reason.) +# +# commitlog_sync: batch +# commitlog_sync_batch_window_in_ms: 2 +# +# the other option is "periodic" where writes may be acked immediately +# and the CommitLog is simply synced every commitlog_sync_period_in_ms +# milliseconds. +commitlog_sync: periodic +commitlog_sync_period_in_ms: 10000 + +# The size of the individual commitlog file segments. A commitlog +# segment may be archived, deleted, or recycled once all the data +# in it (potentially from each columnfamily in the system) has been +# flushed to sstables. +# +# The default size is 32, which is almost always fine, but if you are +# archiving commitlog segments (see commitlog_archiving.properties), +# then you probably want a finer granularity of archiving; 8 or 16 MB +# is reasonable. +# Max mutation size is also configurable via max_mutation_size_in_kb setting in +# cassandra.yaml. The default is half the size commitlog_segment_size_in_mb * 1024. +# +# NOTE: If max_mutation_size_in_kb is set explicitly then commitlog_segment_size_in_mb must +# be set to at least twice the size of max_mutation_size_in_kb / 1024 +# +commitlog_segment_size_in_mb: 32 + +# Compression to apply to the commit log. If omitted, the commit log +# will be written uncompressed. LZ4, Snappy, and Deflate compressors +# are supported. +# commitlog_compression: +# - class_name: LZ4Compressor +# parameters: +# - + +# any class that implements the SeedProvider interface and has a +# constructor that takes a Map of parameters will do. +seed_provider: + # Addresses of hosts that are deemed contact points. + # Cassandra nodes use this list of hosts to find each other and learn + # the topology of the ring. You must change this if you are running + # multiple nodes! + - class_name: org.apache.cassandra.locator.SimpleSeedProvider + parameters: + # seeds is actually a comma-delimited list of addresses. + # Ex: ",," + - seeds: "172.17.0.3" + +# For workloads with more data than can fit in memory, Cassandra's +# bottleneck will be reads that need to fetch data from +# disk. "concurrent_reads" should be set to (16 * number_of_drives) in +# order to allow the operations to enqueue low enough in the stack +# that the OS and drives can reorder them. Same applies to +# "concurrent_counter_writes", since counter writes read the current +# values before incrementing and writing them back. +# +# On the other hand, since writes are almost never IO bound, the ideal +# number of "concurrent_writes" is dependent on the number of cores in +# your system; (8 * number_of_cores) is a good rule of thumb. +concurrent_reads: 32 +concurrent_writes: 32 +concurrent_counter_writes: 32 + +# For materialized view writes, as there is a read involved, so this should +# be limited by the less of concurrent reads or concurrent writes. +concurrent_materialized_view_writes: 32 + +# Maximum memory to use for sstable chunk cache and buffer pooling. +# 32MB of this are reserved for pooling buffers, the rest is used as an +# cache that holds uncompressed sstable chunks. +# Defaults to the smaller of 1/4 of heap or 512MB. This pool is allocated off-heap, +# so is in addition to the memory allocated for heap. The cache also has on-heap +# overhead which is roughly 128 bytes per chunk (i.e. 0.2% of the reserved size +# if the default 64k chunk size is used). +# Memory is only allocated when needed. +# file_cache_size_in_mb: 512 + +# Flag indicating whether to allocate on or off heap when the sstable buffer +# pool is exhausted, that is when it has exceeded the maximum memory +# file_cache_size_in_mb, beyond which it will not cache buffers but allocate on request. + +# buffer_pool_use_heap_if_exhausted: true + +# The strategy for optimizing disk read +# Possible values are: +# ssd (for solid state disks, the default) +# spinning (for spinning disks) +# disk_optimization_strategy: ssd + +# Total permitted memory to use for memtables. Cassandra will stop +# accepting writes when the limit is exceeded until a flush completes, +# and will trigger a flush based on memtable_cleanup_threshold +# If omitted, Cassandra will set both to 1/4 the size of the heap. +# memtable_heap_space_in_mb: 2048 +# memtable_offheap_space_in_mb: 2048 + +# Ratio of occupied non-flushing memtable size to total permitted size +# that will trigger a flush of the largest memtable. Larger mct will +# mean larger flushes and hence less compaction, but also less concurrent +# flush activity which can make it difficult to keep your disks fed +# under heavy write load. +# +# memtable_cleanup_threshold defaults to 1 / (memtable_flush_writers + 1) +# memtable_cleanup_threshold: 0.11 + +# Specify the way Cassandra allocates and manages memtable memory. +# Options are: +# +# heap_buffers +# on heap nio buffers +# +# offheap_buffers +# off heap (direct) nio buffers +# +# offheap_objects +# off heap objects +memtable_allocation_type: heap_buffers + +# Total space to use for commit logs on disk. +# +# If space gets above this value, Cassandra will flush every dirty CF +# in the oldest segment and remove it. So a small total commitlog space +# will tend to cause more flush activity on less-active columnfamilies. +# +# The default value is the smaller of 8192, and 1/4 of the total space +# of the commitlog volume. +# +# commitlog_total_space_in_mb: 8192 + +# This sets the amount of memtable flush writer threads. These will +# be blocked by disk io, and each one will hold a memtable in memory +# while blocked. +# +# memtable_flush_writers defaults to one per data_file_directory. +# +# If your data directories are backed by SSD, you can increase this, but +# avoid having memtable_flush_writers * data_file_directories > number of cores +#memtable_flush_writers: 1 + +# Total space to use for change-data-capture logs on disk. +# +# If space gets above this value, Cassandra will throw WriteTimeoutException +# on Mutations including tables with CDC enabled. A CDCCompactor is responsible +# for parsing the raw CDC logs and deleting them when parsing is completed. +# +# The default value is the min of 4096 mb and 1/8th of the total space +# of the drive where cdc_raw_directory resides. +# cdc_total_space_in_mb: 4096 + +# When we hit our cdc_raw limit and the CDCCompactor is either running behind +# or experiencing backpressure, we check at the following interval to see if any +# new space for cdc-tracked tables has been made available. Default to 250ms +# cdc_free_space_check_interval_ms: 250 + +# A fixed memory pool size in MB for for SSTable index summaries. If left +# empty, this will default to 5% of the heap size. If the memory usage of +# all index summaries exceeds this limit, SSTables with low read rates will +# shrink their index summaries in order to meet this limit. However, this +# is a best-effort process. In extreme conditions Cassandra may need to use +# more than this amount of memory. +index_summary_capacity_in_mb: + +# How frequently index summaries should be resampled. This is done +# periodically to redistribute memory from the fixed-size pool to sstables +# proportional their recent read rates. Setting to -1 will disable this +# process, leaving existing index summaries at their current sampling level. +index_summary_resize_interval_in_minutes: 60 + +# Whether to, when doing sequential writing, fsync() at intervals in +# order to force the operating system to flush the dirty +# buffers. Enable this to avoid sudden dirty buffer flushing from +# impacting read latencies. Almost always a good idea on SSDs; not +# necessarily on platters. +trickle_fsync: false +trickle_fsync_interval_in_kb: 10240 + +# TCP port, for commands and data +# For security reasons, you should not expose this port to the internet. Firewall it if needed. +storage_port: 7000 + +# SSL port, for encrypted communication. Unused unless enabled in +# encryption_options +# For security reasons, you should not expose this port to the internet. Firewall it if needed. +ssl_storage_port: 7001 + +# Address or interface to bind to and tell other Cassandra nodes to connect to. +# You _must_ change this if you want multiple nodes to be able to communicate! +# +# Set listen_address OR listen_interface, not both. +# +# Leaving it blank leaves it up to InetAddress.getLocalHost(). This +# will always do the Right Thing _if_ the node is properly configured +# (hostname, name resolution, etc), and the Right Thing is to use the +# address associated with the hostname (it might not be). +# +# Setting listen_address to 0.0.0.0 is always wrong. +# +listen_address: 172.17.0.3 + +# Set listen_address OR listen_interface, not both. Interfaces must correspond +# to a single address, IP aliasing is not supported. +# listen_interface: eth0 + +# If you choose to specify the interface by name and the interface has an ipv4 and an ipv6 address +# you can specify which should be chosen using listen_interface_prefer_ipv6. If false the first ipv4 +# address will be used. If true the first ipv6 address will be used. Defaults to false preferring +# ipv4. If there is only one address it will be selected regardless of ipv4/ipv6. +# listen_interface_prefer_ipv6: false + +# Address to broadcast to other Cassandra nodes +# Leaving this blank will set it to the same value as listen_address +broadcast_address: 172.17.0.3 + +# When using multiple physical network interfaces, set this +# to true to listen on broadcast_address in addition to +# the listen_address, allowing nodes to communicate in both +# interfaces. +# Ignore this property if the network configuration automatically +# routes between the public and private networks such as EC2. +# listen_on_broadcast_address: false + +# Internode authentication backend, implementing IInternodeAuthenticator; +# used to allow/disallow connections from peer nodes. +# internode_authenticator: org.apache.cassandra.auth.AllowAllInternodeAuthenticator + +# Whether to start the native transport server. +# Please note that the address on which the native transport is bound is the +# same as the rpc_address. The port however is different and specified below. +start_native_transport: true +# port for the CQL native transport to listen for clients on +# For security reasons, you should not expose this port to the internet. Firewall it if needed. +native_transport_port: 9042 +# Enabling native transport encryption in client_encryption_options allows you to either use +# encryption for the standard port or to use a dedicated, additional port along with the unencrypted +# standard native_transport_port. +# Enabling client encryption and keeping native_transport_port_ssl disabled will use encryption +# for native_transport_port. Setting native_transport_port_ssl to a different value +# from native_transport_port will use encryption for native_transport_port_ssl while +# keeping native_transport_port unencrypted. +# native_transport_port_ssl: 9142 +# The maximum threads for handling requests when the native transport is used. +# This is similar to rpc_max_threads though the default differs slightly (and +# there is no native_transport_min_threads, idle threads will always be stopped +# after 30 seconds). +# native_transport_max_threads: 128 +# +# The maximum size of allowed frame. Frame (requests) larger than this will +# be rejected as invalid. The default is 256MB. If you're changing this parameter, +# you may want to adjust max_value_size_in_mb accordingly. +# native_transport_max_frame_size_in_mb: 256 + +# The maximum number of concurrent client connections. +# The default is -1, which means unlimited. +# native_transport_max_concurrent_connections: -1 + +# The maximum number of concurrent client connections per source ip. +# The default is -1, which means unlimited. +# native_transport_max_concurrent_connections_per_ip: -1 + +# Whether to start the thrift rpc server. +start_rpc: false + +# The address or interface to bind the Thrift RPC service and native transport +# server to. +# +# Set rpc_address OR rpc_interface, not both. +# +# Leaving rpc_address blank has the same effect as on listen_address +# (i.e. it will be based on the configured hostname of the node). +# +# Note that unlike listen_address, you can specify 0.0.0.0, but you must also +# set broadcast_rpc_address to a value other than 0.0.0.0. +# +# For security reasons, you should not expose this port to the internet. Firewall it if needed. +rpc_address: 0.0.0.0 + +# Set rpc_address OR rpc_interface, not both. Interfaces must correspond +# to a single address, IP aliasing is not supported. +# rpc_interface: eth1 + +# If you choose to specify the interface by name and the interface has an ipv4 and an ipv6 address +# you can specify which should be chosen using rpc_interface_prefer_ipv6. If false the first ipv4 +# address will be used. If true the first ipv6 address will be used. Defaults to false preferring +# ipv4. If there is only one address it will be selected regardless of ipv4/ipv6. +# rpc_interface_prefer_ipv6: false + +# port for Thrift to listen for clients on +rpc_port: 9160 + +# RPC address to broadcast to drivers and other Cassandra nodes. This cannot +# be set to 0.0.0.0. If left blank, this will be set to the value of +# rpc_address. If rpc_address is set to 0.0.0.0, broadcast_rpc_address must +# be set. +broadcast_rpc_address: 172.17.0.3 + +# enable or disable keepalive on rpc/native connections +rpc_keepalive: true + +# Cassandra provides two out-of-the-box options for the RPC Server: +# +# sync +# One thread per thrift connection. For a very large number of clients, memory +# will be your limiting factor. On a 64 bit JVM, 180KB is the minimum stack size +# per thread, and that will correspond to your use of virtual memory (but physical memory +# may be limited depending on use of stack space). +# +# hsha +# Stands for "half synchronous, half asynchronous." All thrift clients are handled +# asynchronously using a small number of threads that does not vary with the amount +# of thrift clients (and thus scales well to many clients). The rpc requests are still +# synchronous (one thread per active request). If hsha is selected then it is essential +# that rpc_max_threads is changed from the default value of unlimited. +# +# The default is sync because on Windows hsha is about 30% slower. On Linux, +# sync/hsha performance is about the same, with hsha of course using less memory. +# +# Alternatively, can provide your own RPC server by providing the fully-qualified class name +# of an o.a.c.t.TServerFactory that can create an instance of it. +rpc_server_type: sync + +# Uncomment rpc_min|max_thread to set request pool size limits. +# +# Regardless of your choice of RPC server (see above), the number of maximum requests in the +# RPC thread pool dictates how many concurrent requests are possible (but if you are using the sync +# RPC server, it also dictates the number of clients that can be connected at all). +# +# The default is unlimited and thus provides no protection against clients overwhelming the server. You are +# encouraged to set a maximum that makes sense for you in production, but do keep in mind that +# rpc_max_threads represents the maximum number of client requests this server may execute concurrently. +# +# rpc_min_threads: 16 +# rpc_max_threads: 2048 + +# uncomment to set socket buffer sizes on rpc connections +# rpc_send_buff_size_in_bytes: +# rpc_recv_buff_size_in_bytes: + +# Uncomment to set socket buffer size for internode communication +# Note that when setting this, the buffer size is limited by net.core.wmem_max +# and when not setting it it is defined by net.ipv4.tcp_wmem +# See also: +# /proc/sys/net/core/wmem_max +# /proc/sys/net/core/rmem_max +# /proc/sys/net/ipv4/tcp_wmem +# /proc/sys/net/ipv4/tcp_wmem +# and 'man tcp' +# internode_send_buff_size_in_bytes: + +# Uncomment to set socket buffer size for internode communication +# Note that when setting this, the buffer size is limited by net.core.wmem_max +# and when not setting it it is defined by net.ipv4.tcp_wmem +# internode_recv_buff_size_in_bytes: + +# Frame size for thrift (maximum message length). +thrift_framed_transport_size_in_mb: 15 + +# Set to true to have Cassandra create a hard link to each sstable +# flushed or streamed locally in a backups/ subdirectory of the +# keyspace data. Removing these links is the operator's +# responsibility. +incremental_backups: false + +# Whether or not to take a snapshot before each compaction. Be +# careful using this option, since Cassandra won't clean up the +# snapshots for you. Mostly useful if you're paranoid when there +# is a data format change. +snapshot_before_compaction: false + +# Whether or not a snapshot is taken of the data before keyspace truncation +# or dropping of column families. The STRONGLY advised default of true +# should be used to provide data safety. If you set this flag to false, you will +# lose data on truncation or drop. +auto_snapshot: true + +# Granularity of the collation index of rows within a partition. +# Increase if your rows are large, or if you have a very large +# number of rows per partition. The competing goals are these: +# +# - a smaller granularity means more index entries are generated +# and looking up rows withing the partition by collation column +# is faster +# - but, Cassandra will keep the collation index in memory for hot +# rows (as part of the key cache), so a larger granularity means +# you can cache more hot rows +column_index_size_in_kb: 64 + +# Per sstable indexed key cache entries (the collation index in memory +# mentioned above) exceeding this size will not be held on heap. +# This means that only partition information is held on heap and the +# index entries are read from disk. +# +# Note that this size refers to the size of the +# serialized index information and not the size of the partition. +column_index_cache_size_in_kb: 2 + +# Number of simultaneous compactions to allow, NOT including +# validation "compactions" for anti-entropy repair. Simultaneous +# compactions can help preserve read performance in a mixed read/write +# workload, by mitigating the tendency of small sstables to accumulate +# during a single long running compactions. The default is usually +# fine and if you experience problems with compaction running too +# slowly or too fast, you should look at +# compaction_throughput_mb_per_sec first. +# +# concurrent_compactors defaults to the smaller of (number of disks, +# number of cores), with a minimum of 2 and a maximum of 8. +# +# If your data directories are backed by SSD, you should increase this +# to the number of cores. +#concurrent_compactors: 1 + +# Throttles compaction to the given total throughput across the entire +# system. The faster you insert data, the faster you need to compact in +# order to keep the sstable count down, but in general, setting this to +# 16 to 32 times the rate you are inserting data is more than sufficient. +# Setting this to 0 disables throttling. Note that this account for all types +# of compaction, including validation compaction. +compaction_throughput_mb_per_sec: 16 + +# When compacting, the replacement sstable(s) can be opened before they +# are completely written, and used in place of the prior sstables for +# any range that has been written. This helps to smoothly transfer reads +# between the sstables, reducing page cache churn and keeping hot rows hot +sstable_preemptive_open_interval_in_mb: 50 + +# Throttles all outbound streaming file transfers on this node to the +# given total throughput in Mbps. This is necessary because Cassandra does +# mostly sequential IO when streaming data during bootstrap or repair, which +# can lead to saturating the network connection and degrading rpc performance. +# When unset, the default is 200 Mbps or 25 MB/s. +# stream_throughput_outbound_megabits_per_sec: 200 + +# Throttles all streaming file transfer between the datacenters, +# this setting allows users to throttle inter dc stream throughput in addition +# to throttling all network stream traffic as configured with +# stream_throughput_outbound_megabits_per_sec +# When unset, the default is 200 Mbps or 25 MB/s +# inter_dc_stream_throughput_outbound_megabits_per_sec: 200 + +# How long the coordinator should wait for read operations to complete +read_request_timeout_in_ms: 5000 +# How long the coordinator should wait for seq or index scans to complete +range_request_timeout_in_ms: 10000 +# How long the coordinator should wait for writes to complete +write_request_timeout_in_ms: 2000 +# How long the coordinator should wait for counter writes to complete +counter_write_request_timeout_in_ms: 5000 +# How long a coordinator should continue to retry a CAS operation +# that contends with other proposals for the same row +cas_contention_timeout_in_ms: 1000 +# How long the coordinator should wait for truncates to complete +# (This can be much longer, because unless auto_snapshot is disabled +# we need to flush first so we can snapshot before removing the data.) +truncate_request_timeout_in_ms: 60000 +# The default timeout for other, miscellaneous operations +request_timeout_in_ms: 10000 + +# Enable operation timeout information exchange between nodes to accurately +# measure request timeouts. If disabled, replicas will assume that requests +# were forwarded to them instantly by the coordinator, which means that +# under overload conditions we will waste that much extra time processing +# already-timed-out requests. +# +# Warning: before enabling this property make sure to ntp is installed +# and the times are synchronized between the nodes. +cross_node_timeout: false + +# Set socket timeout for streaming operation. +# The stream session is failed if no data/ack is received by any of the participants +# within that period, which means this should also be sufficient to stream a large +# sstable or rebuild table indexes. +# Default value is 86400000ms, which means stale streams timeout after 24 hours. +# A value of zero means stream sockets should never time out. +# streaming_socket_timeout_in_ms: 86400000 + +# phi value that must be reached for a host to be marked down. +# most users should never need to adjust this. +# phi_convict_threshold: 8 + +# endpoint_snitch -- Set this to a class that implements +# IEndpointSnitch. The snitch has two functions: +# +# - it teaches Cassandra enough about your network topology to route +# requests efficiently +# - it allows Cassandra to spread replicas around your cluster to avoid +# correlated failures. It does this by grouping machines into +# "datacenters" and "racks." Cassandra will do its best not to have +# more than one replica on the same "rack" (which may not actually +# be a physical location) +# +# CASSANDRA WILL NOT ALLOW YOU TO SWITCH TO AN INCOMPATIBLE SNITCH +# ONCE DATA IS INSERTED INTO THE CLUSTER. This would cause data loss. +# This means that if you start with the default SimpleSnitch, which +# locates every node on "rack1" in "datacenter1", your only options +# if you need to add another datacenter are GossipingPropertyFileSnitch +# (and the older PFS). From there, if you want to migrate to an +# incompatible snitch like Ec2Snitch you can do it by adding new nodes +# under Ec2Snitch (which will locate them in a new "datacenter") and +# decommissioning the old ones. +# +# Out of the box, Cassandra provides: +# +# SimpleSnitch: +# Treats Strategy order as proximity. This can improve cache +# locality when disabling read repair. Only appropriate for +# single-datacenter deployments. +# +# GossipingPropertyFileSnitch +# This should be your go-to snitch for production use. The rack +# and datacenter for the local node are defined in +# cassandra-rackdc.properties and propagated to other nodes via +# gossip. If cassandra-topology.properties exists, it is used as a +# fallback, allowing migration from the PropertyFileSnitch. +# +# PropertyFileSnitch: +# Proximity is determined by rack and data center, which are +# explicitly configured in cassandra-topology.properties. +# +# Ec2Snitch: +# Appropriate for EC2 deployments in a single Region. Loads Region +# and Availability Zone information from the EC2 API. The Region is +# treated as the datacenter, and the Availability Zone as the rack. +# Only private IPs are used, so this will not work across multiple +# Regions. +# +# Ec2MultiRegionSnitch: +# Uses public IPs as broadcast_address to allow cross-region +# connectivity. (Thus, you should set seed addresses to the public +# IP as well.) You will need to open the storage_port or +# ssl_storage_port on the public IP firewall. (For intra-Region +# traffic, Cassandra will switch to the private IP after +# establishing a connection.) +# +# RackInferringSnitch: +# Proximity is determined by rack and data center, which are +# assumed to correspond to the 3rd and 2nd octet of each node's IP +# address, respectively. Unless this happens to match your +# deployment conventions, this is best used as an example of +# writing a custom Snitch class and is provided in that spirit. +# +# You can use a custom Snitch by setting this to the full class name +# of the snitch, which will be assumed to be on your classpath. +endpoint_snitch: SimpleSnitch + +# controls how often to perform the more expensive part of host score +# calculation +dynamic_snitch_update_interval_in_ms: 100 +# controls how often to reset all host scores, allowing a bad host to +# possibly recover +dynamic_snitch_reset_interval_in_ms: 600000 +# if set greater than zero and read_repair_chance is < 1.0, this will allow +# 'pinning' of replicas to hosts in order to increase cache capacity. +# The badness threshold will control how much worse the pinned host has to be +# before the dynamic snitch will prefer other replicas over it. This is +# expressed as a double which represents a percentage. Thus, a value of +# 0.2 means Cassandra would continue to prefer the static snitch values +# until the pinned host was 20% worse than the fastest. +dynamic_snitch_badness_threshold: 0.1 + +# request_scheduler -- Set this to a class that implements +# RequestScheduler, which will schedule incoming client requests +# according to the specific policy. This is useful for multi-tenancy +# with a single Cassandra cluster. +# NOTE: This is specifically for requests from the client and does +# not affect inter node communication. +# org.apache.cassandra.scheduler.NoScheduler - No scheduling takes place +# org.apache.cassandra.scheduler.RoundRobinScheduler - Round robin of +# client requests to a node with a separate queue for each +# request_scheduler_id. The scheduler is further customized by +# request_scheduler_options as described below. +request_scheduler: org.apache.cassandra.scheduler.NoScheduler + +# Scheduler Options vary based on the type of scheduler +# +# NoScheduler +# Has no options +# +# RoundRobin +# throttle_limit +# The throttle_limit is the number of in-flight +# requests per client. Requests beyond +# that limit are queued up until +# running requests can complete. +# The value of 80 here is twice the number of +# concurrent_reads + concurrent_writes. +# default_weight +# default_weight is optional and allows for +# overriding the default which is 1. +# weights +# Weights are optional and will default to 1 or the +# overridden default_weight. The weight translates into how +# many requests are handled during each turn of the +# RoundRobin, based on the scheduler id. +# +# request_scheduler_options: +# throttle_limit: 80 +# default_weight: 5 +# weights: +# Keyspace1: 1 +# Keyspace2: 5 + +# request_scheduler_id -- An identifier based on which to perform +# the request scheduling. Currently the only valid option is keyspace. +# request_scheduler_id: keyspace + +# Enable or disable inter-node encryption +# JVM defaults for supported SSL socket protocols and cipher suites can +# be replaced using custom encryption options. This is not recommended +# unless you have policies in place that dictate certain settings, or +# need to disable vulnerable ciphers or protocols in case the JVM cannot +# be updated. +# FIPS compliant settings can be configured at JVM level and should not +# involve changing encryption settings here: +# https://docs.oracle.com/javase/8/docs/technotes/guides/security/jsse/FIPS.html +# *NOTE* No custom encryption options are enabled at the moment +# The available internode options are : all, none, dc, rack +# +# If set to dc cassandra will encrypt the traffic between the DCs +# If set to rack cassandra will encrypt the traffic between the racks +# +# The passwords used in these options must match the passwords used when generating +# the keystore and truststore. For instructions on generating these files, see: +# http://download.oracle.com/javase/6/docs/technotes/guides/security/jsse/JSSERefGuide.html#CreateKeystore +# +server_encryption_options: + internode_encryption: none + keystore: conf/.keystore + keystore_password: cassandra + truststore: conf/.truststore + truststore_password: cassandra + # More advanced defaults below: + # protocol: TLS + # algorithm: SunX509 + # store_type: JKS + # cipher_suites: [TLS_RSA_WITH_AES_128_CBC_SHA,TLS_RSA_WITH_AES_256_CBC_SHA,TLS_DHE_RSA_WITH_AES_128_CBC_SHA,TLS_DHE_RSA_WITH_AES_256_CBC_SHA,TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA] + # require_client_auth: false + # require_endpoint_verification: false + +# enable or disable client/server encryption. +client_encryption_options: + enabled: false + # If enabled and optional is set to true encrypted and unencrypted connections are handled. + optional: false + keystore: conf/.keystore + keystore_password: cassandra + # require_client_auth: false + # Set trustore and truststore_password if require_client_auth is true + # truststore: conf/.truststore + # truststore_password: cassandra + # More advanced defaults below: + # protocol: TLS + # algorithm: SunX509 + # store_type: JKS + # cipher_suites: [TLS_RSA_WITH_AES_128_CBC_SHA,TLS_RSA_WITH_AES_256_CBC_SHA,TLS_DHE_RSA_WITH_AES_128_CBC_SHA,TLS_DHE_RSA_WITH_AES_256_CBC_SHA,TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA] + +# internode_compression controls whether traffic between nodes is +# compressed. +# Can be: +# +# all +# all traffic is compressed +# +# dc +# traffic between different datacenters is compressed +# +# none +# nothing is compressed. +internode_compression: dc + +# Enable or disable tcp_nodelay for inter-dc communication. +# Disabling it will result in larger (but fewer) network packets being sent, +# reducing overhead from the TCP protocol itself, at the cost of increasing +# latency if you block for cross-datacenter responses. +inter_dc_tcp_nodelay: false + +# TTL for different trace types used during logging of the repair process. +tracetype_query_ttl: 86400 +tracetype_repair_ttl: 604800 + +# By default, Cassandra logs GC Pauses greater than 200 ms at INFO level +# This threshold can be adjusted to minimize logging if necessary +# gc_log_threshold_in_ms: 200 + +# If unset, all GC Pauses greater than gc_log_threshold_in_ms will log at +# INFO level +# UDFs (user defined functions) are disabled by default. +# As of Cassandra 3.0 there is a sandbox in place that should prevent execution of evil code. +enable_user_defined_functions: false + +# Enables scripted UDFs (JavaScript UDFs). +# Java UDFs are always enabled, if enable_user_defined_functions is true. +# Enable this option to be able to use UDFs with "language javascript" or any custom JSR-223 provider. +# This option has no effect, if enable_user_defined_functions is false. +enable_scripted_user_defined_functions: false + +# The default Windows kernel timer and scheduling resolution is 15.6ms for power conservation. +# Lowering this value on Windows can provide much tighter latency and better throughput, however +# some virtualized environments may see a negative performance impact from changing this setting +# below their system default. The sysinternals 'clockres' tool can confirm your system's default +# setting. +windows_timer_interval: 1 + + +# Enables encrypting data at-rest (on disk). Different key providers can be plugged in, but the default reads from +# a JCE-style keystore. A single keystore can hold multiple keys, but the one referenced by +# the "key_alias" is the only key that will be used for encrypt opertaions; previously used keys +# can still (and should!) be in the keystore and will be used on decrypt operations +# (to handle the case of key rotation). +# +# It is strongly recommended to download and install Java Cryptography Extension (JCE) +# Unlimited Strength Jurisdiction Policy Files for your version of the JDK. +# (current link: http://www.oracle.com/technetwork/java/javase/downloads/jce8-download-2133166.html) +# +# Currently, only the following file types are supported for transparent data encryption, although +# more are coming in future cassandra releases: commitlog, hints +transparent_data_encryption_options: + enabled: false + chunk_length_kb: 64 + cipher: AES/CBC/PKCS5Padding + key_alias: testing:1 + # CBC IV length for AES needs to be 16 bytes (which is also the default size) + # iv_length: 16 + key_provider: + - class_name: org.apache.cassandra.security.JKSKeyProvider + parameters: + - keystore: conf/.keystore + keystore_password: cassandra + store_type: JCEKS + key_password: cassandra + + +##################### +# SAFETY THRESHOLDS # +##################### + +# When executing a scan, within or across a partition, we need to keep the +# tombstones seen in memory so we can return them to the coordinator, which +# will use them to make sure other replicas also know about the deleted rows. +# With workloads that generate a lot of tombstones, this can cause performance +# problems and even exaust the server heap. +# (http://www.datastax.com/dev/blog/cassandra-anti-patterns-queues-and-queue-like-datasets) +# Adjust the thresholds here if you understand the dangers and want to +# scan more tombstones anyway. These thresholds may also be adjusted at runtime +# using the StorageService mbean. +tombstone_warn_threshold: 1000 +tombstone_failure_threshold: 100000 + +# Log WARN on any batch size exceeding this value. 5kb per batch by default. +# Caution should be taken on increasing the size of this threshold as it can lead to node instability. +batch_size_warn_threshold_in_kb: 5 + +# Fail any batch exceeding this value. 50kb (10x warn threshold) by default. +batch_size_fail_threshold_in_kb: 50 + +# Log WARN on any batches not of type LOGGED than span across more partitions than this limit +unlogged_batch_across_partitions_warn_threshold: 10 + +# Log a warning when compacting partitions larger than this value +compaction_large_partition_warning_threshold_mb: 100 + +# GC Pauses greater than gc_warn_threshold_in_ms will be logged at WARN level +# Adjust the threshold based on your application throughput requirement +# By default, Cassandra logs GC Pauses greater than 200 ms at INFO level +gc_warn_threshold_in_ms: 1000 + +# Maximum size of any value in SSTables. Safety measure to detect SSTable corruption +# early. Any value size larger than this threshold will result into marking an SSTable +# as corrupted. +# max_value_size_in_mb: 256 diff --git a/plugins/helper/database/connutil/cassandra.go b/plugins/helper/database/connutil/cassandra.go index 028c6814fb83..1babc3cbde8b 100644 --- a/plugins/helper/database/connutil/cassandra.go +++ b/plugins/helper/database/connutil/cassandra.go @@ -31,6 +31,7 @@ type CassandraConnectionProducer struct { Consistency string `json:"consistency" structs:"consistency" mapstructure:"consistency"` Initialized bool + Type string session *gocql.Session sync.Mutex } @@ -46,14 +47,14 @@ func (c *CassandraConnectionProducer) Initialize(conf map[string]interface{}, ve c.Initialized = true if verifyConnection { - if _, err := c.connection(); err != nil { + if _, err := c.Connection(); err != nil { return fmt.Errorf("error Initalizing Connection: %s", err) } } return nil } -func (c *CassandraConnectionProducer) connection() (interface{}, error) { +func (c *CassandraConnectionProducer) Connection() (interface{}, error) { if !c.Initialized { return nil, errNotInitialized } @@ -106,7 +107,7 @@ func (c *CassandraConnectionProducer) createSession() (*gocql.Session, error) { var tlsConfig *tls.Config if len(c.Certificate) > 0 || len(c.IssuingCA) > 0 { if len(c.Certificate) > 0 && len(c.PrivateKey) == 0 { - return nil, fmt.Errorf("Found certificate for TLS authentication but no private key") + return nil, fmt.Errorf("found certificate for TLS authentication but no private key") } certBundle := &certutil.CertBundle{} From 2faa08dfba307a5363be311a0aab5f1fe9e19f4d Mon Sep 17 00:00:00 2001 From: Calvin Leung Huang Date: Sun, 23 Apr 2017 00:04:05 -0400 Subject: [PATCH 088/152] Remove commented old method signature --- plugins/database/cassandra/cassandra.go | 1 - 1 file changed, 1 deletion(-) diff --git a/plugins/database/cassandra/cassandra.go b/plugins/database/cassandra/cassandra.go index 621d6e375282..15df0352e445 100644 --- a/plugins/database/cassandra/cassandra.go +++ b/plugins/database/cassandra/cassandra.go @@ -60,7 +60,6 @@ func (c *Cassandra) getConnection() (*gocql.Session, error) { return session.(*gocql.Session), nil } -// func (c *Cassandra) CreateUser(statements dbplugin.Statements, username, password, expiration string) error { func (c *Cassandra) CreateUser(statements dbplugin.Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) { // Grab the lock c.Lock() From f4ef3df4bd2a3085d84aebe9bc503a22fd77dd8d Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 24 Apr 2017 10:30:33 -0700 Subject: [PATCH 089/152] Update the builtin keys; move catalog to core; protect against unset plugin directory --- command/server.go | 19 ---------------- helper/builtinplugins/builtin.go | 24 +++++++-------------- helper/pluginutil/runner.go | 4 ++-- vault/logical_system.go | 13 +++++++---- vault/logical_system_test.go | 6 +++--- vault/plugin_catalog.go | 37 +++++++++++++++++++------------- vault/plugin_catalog_test.go | 6 +++--- 7 files changed, 47 insertions(+), 62 deletions(-) diff --git a/command/server.go b/command/server.go index ef9e3e3a0c97..9697c1dc853b 100644 --- a/command/server.go +++ b/command/server.go @@ -8,7 +8,6 @@ import ( "net/url" "os" "os/signal" - "path/filepath" "runtime" "sort" "strconv" @@ -21,7 +20,6 @@ import ( colorable "github.com/mattn/go-colorable" log "github.com/mgutz/logxi/v1" - homedir "github.com/mitchellh/go-homedir" "google.golang.org/grpc/grpclog" @@ -245,23 +243,6 @@ func (c *ServerCommand) Run(args []string) int { coreConfig.DevToken = devRootTokenID } - if config.PluginDirectory == "" { - homePath, err := homedir.Dir() - if err != nil { - c.Ui.Output(fmt.Sprintf( - "Error getting user's home directory: %v", err)) - return 1 - } - coreConfig.PluginDirectory = filepath.Join(homePath, "/.vault-plugins/") - err = os.Mkdir(coreConfig.PluginDirectory, 0700) - if err != nil && !os.IsExist(err) { - c.Ui.Output(fmt.Sprintf( - "Error making default plugin directory: %v", err)) - return 1 - } - - } - var disableClustering bool // Initialize the separate HA storage backend, if it exists diff --git a/helper/builtinplugins/builtin.go b/helper/builtinplugins/builtin.go index 9c51ae47898b..b61a51710d22 100644 --- a/helper/builtinplugins/builtin.go +++ b/helper/builtinplugins/builtin.go @@ -7,29 +7,21 @@ import ( type BuiltinFactory func() (interface{}, error) -var BuiltinPlugins *builtinPlugins = &builtinPlugins{ - plugins: map[string]BuiltinFactory{ - "mysql-database-plugin": mysql.New, - "postgresql-database-plugin": postgresql.New, - }, +var plugins map[string]BuiltinFactory = map[string]BuiltinFactory{ + "mysql-database-plugin": mysql.New, + "postgresql-database-plugin": postgresql.New, } -// The list of builtin plugins should not be changed by any other package, so we -// store them in an unexported variable in this unexported struct. -type builtinPlugins struct { - plugins map[string]BuiltinFactory -} - -func (b *builtinPlugins) Get(name string) (BuiltinFactory, bool) { - f, ok := b.plugins[name] +func Get(name string) (BuiltinFactory, bool) { + f, ok := plugins[name] return f, ok } -func (b *builtinPlugins) Keys() []string { - keys := make([]string, len(b.plugins)) +func Keys() []string { + keys := make([]string, len(plugins)) i := 0 - for k := range b.plugins { + for k := range plugins { keys[i] = k i++ } diff --git a/helper/pluginutil/runner.go b/helper/pluginutil/runner.go index 95de96a5a8c6..9963704e5464 100644 --- a/helper/pluginutil/runner.go +++ b/helper/pluginutil/runner.go @@ -12,8 +12,8 @@ import ( ) var ( - // PluginUnwrapTokenEnv is the ENV name used to pass unwrap tokens to the - // plugin. + // PluginUnwrapTokenEnv is the ENV name used to pass the configuration for + // enabling mlock PluginMlockEnabled = "VAULT_PLUGIN_MLOCK_ENABLED" ) diff --git a/vault/logical_system.go b/vault/logical_system.go index f43de9ef67f8..cd7113aa3078 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -710,13 +710,19 @@ func NewSystemBackend(core *Core, config *logical.BackendConfig) (logical.Backen Fields: map[string]*framework.FieldSchema{ "name": &framework.FieldSchema{ - Type: framework.TypeString, + Type: framework.TypeString, + Description: "The name of the plugin", }, "sha_256": &framework.FieldSchema{ Type: framework.TypeString, + Description: `The SHA256 sum of the executable used in the + command field. This should be HEX encoded.`, }, "command": &framework.FieldSchema{ Type: framework.TypeString, + Description: `The command used to start the plugin. The + executable defined in this command must exist in vault's + plugin directory.`, }, }, @@ -767,8 +773,7 @@ func (b *SystemBackend) handlePluginCatalogList(req *logical.Request, d *framewo return nil, err } - resp := logical.ListResponse(plugins) - return resp, nil + return logical.ListResponse(plugins), nil } func (b *SystemBackend) handlePluginCatalogUpdate(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { @@ -2524,7 +2529,7 @@ This path responds to the following HTTP methods. `Configures the plugins known to vault`, ` This path responds to the following HTTP methods. - GET / + LIST / Returns a list of names of configured plugins. GET / diff --git a/vault/logical_system_test.go b/vault/logical_system_test.go index e9836946c3bf..ea940d540dd5 100644 --- a/vault/logical_system_test.go +++ b/vault/logical_system_test.go @@ -1129,8 +1129,8 @@ func TestSystemBackend_PluginCatalog_CRUD(t *testing.T) { t.Fatalf("err: %v", err) } - if len(resp.Data["keys"].([]string)) != len(builtinplugins.BuiltinPlugins.Keys()) { - t.Fatalf("Wrong number of plugins, got %d, expected %d", len(resp.Data["keys"].([]string)), len(builtinplugins.BuiltinPlugins.Keys())) + if len(resp.Data["keys"].([]string)) != len(builtinplugins.Keys()) { + t.Fatalf("Wrong number of plugins, got %d, expected %d", len(resp.Data["keys"].([]string)), len(builtinplugins.Keys())) } req = logical.TestRequest(t, logical.ReadOperation, "plugin-catalog/mysql-database-plugin") @@ -1143,7 +1143,7 @@ func TestSystemBackend_PluginCatalog_CRUD(t *testing.T) { Name: "mysql-database-plugin", Builtin: true, } - expectedBuiltin.BuiltinFactory, _ = builtinplugins.BuiltinPlugins.Get("mysql-database-plugin") + expectedBuiltin.BuiltinFactory, _ = builtinplugins.Get("mysql-database-plugin") p := resp.Data["plugin"].(*pluginutil.PluginRunner) if &(p.BuiltinFactory) == &(expectedBuiltin.BuiltinFactory) { diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go index 598a16fac553..5d88873b3189 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugin_catalog.go @@ -16,7 +16,8 @@ import ( ) var ( - pluginCatalogPrefix = "plugin-catalog/" + pluginCatalogPath = "core/plugin-catalog/" + ErrDirectoryNotConfigured = errors.New("could not set plugin, plugin directory is not configured") ) // PluginCatalog keeps a record of plugins known to vault. External plugins need @@ -31,7 +32,7 @@ type PluginCatalog struct { func (c *Core) setupPluginCatalog() error { c.pluginCatalog = &PluginCatalog{ - catalogView: c.systemBarrierView.SubView(pluginCatalogPrefix), + catalogView: NewBarrierView(c.barrier, pluginCatalogPath), directory: c.pluginDirectory, } @@ -45,22 +46,24 @@ func (c *PluginCatalog) Get(name string) (*pluginutil.PluginRunner, error) { c.lock.RLock() defer c.lock.RUnlock() - // Look for external plugins in the barrier - out, err := c.catalogView.Get(name) - if err != nil { - return nil, fmt.Errorf("failed to retrieve plugin \"%s\": %v", name, err) - } - if out != nil { - entry := new(pluginutil.PluginRunner) - if err := jsonutil.DecodeJSON(out.Value, entry); err != nil { - return nil, fmt.Errorf("failed to decode plugin entry: %v", err) + // If the directory isn't set only look for builtin plugins. + if c.directory != "" { + // Look for external plugins in the barrier + out, err := c.catalogView.Get(name) + if err != nil { + return nil, fmt.Errorf("failed to retrieve plugin \"%s\": %v", name, err) } + if out != nil { + entry := new(pluginutil.PluginRunner) + if err := jsonutil.DecodeJSON(out.Value, entry); err != nil { + return nil, fmt.Errorf("failed to decode plugin entry: %v", err) + } - return entry, nil + return entry, nil + } } - // Look for builtin plugins - if factory, ok := builtinplugins.BuiltinPlugins.Get(name); ok { + if factory, ok := builtinplugins.Get(name); ok { return &pluginutil.PluginRunner{ Name: name, Builtin: true, @@ -74,6 +77,10 @@ func (c *PluginCatalog) Get(name string) (*pluginutil.PluginRunner, error) { // Set registers a new external plugin with the catalog, or updates an existing // external plugin. It takes the name, command and SHA256 of the plugin. func (c *PluginCatalog) Set(name, command string, sha256 []byte) error { + if c.directory == "" { + return ErrDirectoryNotConfigured + } + c.lock.Lock() defer c.lock.Unlock() @@ -143,7 +150,7 @@ func (c *PluginCatalog) List() ([]string, error) { } // Get the keys for builtin plugins - builtinKeys := builtinplugins.BuiltinPlugins.Keys() + builtinKeys := builtinplugins.Keys() // Use a map to unique the two lists mapKeys := make(map[string]bool) diff --git a/vault/plugin_catalog_test.go b/vault/plugin_catalog_test.go index 57e864892bb0..6cfacda7e576 100644 --- a/vault/plugin_catalog_test.go +++ b/vault/plugin_catalog_test.go @@ -32,7 +32,7 @@ func TestPluginCatalog_CRUD(t *testing.T) { Name: "mysql-database-plugin", Builtin: true, } - expectedBuiltin.BuiltinFactory, _ = builtinplugins.BuiltinPlugins.Get("mysql-database-plugin") + expectedBuiltin.BuiltinFactory, _ = builtinplugins.Get("mysql-database-plugin") if &(p.BuiltinFactory) == &(expectedBuiltin.BuiltinFactory) { t.Fatal("expected BuiltinFactory did not match actual") @@ -90,7 +90,7 @@ func TestPluginCatalog_CRUD(t *testing.T) { Name: "mysql-database-plugin", Builtin: true, } - expectedBuiltin.BuiltinFactory, _ = builtinplugins.BuiltinPlugins.Get("mysql-database-plugin") + expectedBuiltin.BuiltinFactory, _ = builtinplugins.Get("mysql-database-plugin") if &(p.BuiltinFactory) == &(expectedBuiltin.BuiltinFactory) { t.Fatal("expected BuiltinFactory did not match actual") @@ -113,7 +113,7 @@ func TestPluginCatalog_List(t *testing.T) { core.pluginCatalog.directory = sym // Get builtin plugins and sort them - builtinKeys := builtinplugins.BuiltinPlugins.Keys() + builtinKeys := builtinplugins.Keys() sort.Strings(builtinKeys) // List only builtin plugins From 707e6caf0cd1ea6570cc154cb6abdf10d2dc7fd5 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 24 Apr 2017 11:35:32 -0700 Subject: [PATCH 090/152] Update path for the plugin catalog in logical system --- vault/logical_system.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vault/logical_system.go b/vault/logical_system.go index cd7113aa3078..843483449ffb 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -63,7 +63,7 @@ func NewSystemBackend(core *Core, config *logical.BackendConfig) (logical.Backen "replication/reindex", "rotate", "config/auditing/*", - "plugin-catalog/*", + "plugins/catalog/*", }, Unauthenticated: []string{ @@ -694,7 +694,7 @@ func NewSystemBackend(core *Core, config *logical.BackendConfig) (logical.Backen HelpDescription: strings.TrimSpace(sysHelp["audited-headers"][1]), }, &framework.Path{ - Pattern: "plugin-catalog/$", + Pattern: "plugins/catalog/$", Fields: map[string]*framework.FieldSchema{}, @@ -706,7 +706,7 @@ func NewSystemBackend(core *Core, config *logical.BackendConfig) (logical.Backen HelpDescription: strings.TrimSpace(sysHelp["plugin-catalog"][1]), }, &framework.Path{ - Pattern: "plugin-catalog/(?P.+)", + Pattern: "plugins/catalog/(?P.+)", Fields: map[string]*framework.FieldSchema{ "name": &framework.FieldSchema{ @@ -2525,7 +2525,7 @@ This path responds to the following HTTP methods. "Lists the headers configured to be audited.", `Returns a list of headers that have been configured to be audited.`, }, - "plugin-catalog": { + "plugins/catalog": { `Configures the plugins known to vault`, ` This path responds to the following HTTP methods. From 4cda9ea3fe1e193c3aeee67c0ff1c23e9c3636ca Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 24 Apr 2017 12:15:01 -0700 Subject: [PATCH 091/152] Update the ResponseWrapData function to return a wrapping.ResponseWrapInfo object --- audit/hashstructure.go | 3 ++- audit/hashstructure_test.go | 7 ++++--- helper/pluginutil/mlock.go | 23 +++++++++++++++++++++++ helper/pluginutil/runner.go | 21 ++------------------- helper/pluginutil/tls.go | 7 +++++-- helper/wrapping/wrapinfo.go | 23 +++++++++++++++++++++++ logical/response.go | 26 +++----------------------- logical/system_view.go | 7 ++++--- vault/dynamic_system_view.go | 11 ++++++----- vault/logical_system.go | 3 ++- vault/request_handling.go | 5 +++-- 11 files changed, 77 insertions(+), 59 deletions(-) create mode 100644 helper/pluginutil/mlock.go create mode 100644 helper/wrapping/wrapinfo.go diff --git a/audit/hashstructure.go b/audit/hashstructure.go index 8d0fd7c6c7c6..ea0899ee9731 100644 --- a/audit/hashstructure.go +++ b/audit/hashstructure.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/hashicorp/vault/helper/salt" + "github.com/hashicorp/vault/helper/wrapping" "github.com/hashicorp/vault/logical" "github.com/mitchellh/copystructure" "github.com/mitchellh/reflectwalk" @@ -84,7 +85,7 @@ func Hash(salter *salt.Salt, raw interface{}) error { s.Data = data.(map[string]interface{}) - case *logical.ResponseWrapInfo: + case *wrapping.ResponseWrapInfo: if s == nil { return nil } diff --git a/audit/hashstructure_test.go b/audit/hashstructure_test.go index 5fefa0fa9158..6916d0d3a308 100644 --- a/audit/hashstructure_test.go +++ b/audit/hashstructure_test.go @@ -9,6 +9,7 @@ import ( "github.com/hashicorp/vault/helper/certutil" "github.com/hashicorp/vault/helper/salt" + "github.com/hashicorp/vault/helper/wrapping" "github.com/hashicorp/vault/logical" "github.com/mitchellh/copystructure" ) @@ -69,7 +70,7 @@ func TestCopy_response(t *testing.T) { Data: map[string]interface{}{ "foo": "bar", }, - WrapInfo: &logical.ResponseWrapInfo{ + WrapInfo: &wrapping.ResponseWrapInfo{ TTL: 60, Token: "foo", CreationTime: time.Now(), @@ -140,7 +141,7 @@ func TestHash(t *testing.T) { Data: map[string]interface{}{ "foo": "bar", }, - WrapInfo: &logical.ResponseWrapInfo{ + WrapInfo: &wrapping.ResponseWrapInfo{ TTL: 60, Token: "bar", CreationTime: now, @@ -151,7 +152,7 @@ func TestHash(t *testing.T) { Data: map[string]interface{}{ "foo": "hmac-sha256:f9320baf0249169e73850cd6156ded0106e2bb6ad8cab01b7bbbebe6d1065317", }, - WrapInfo: &logical.ResponseWrapInfo{ + WrapInfo: &wrapping.ResponseWrapInfo{ TTL: 60, Token: "hmac-sha256:f9320baf0249169e73850cd6156ded0106e2bb6ad8cab01b7bbbebe6d1065317", CreationTime: now, diff --git a/helper/pluginutil/mlock.go b/helper/pluginutil/mlock.go new file mode 100644 index 000000000000..dd9115a89a29 --- /dev/null +++ b/helper/pluginutil/mlock.go @@ -0,0 +1,23 @@ +package pluginutil + +import ( + "os" + + "github.com/hashicorp/vault/helper/mlock" +) + +var ( + // PluginUnwrapTokenEnv is the ENV name used to pass the configuration for + // enabling mlock + PluginMlockEnabled = "VAULT_PLUGIN_MLOCK_ENABLED" +) + +// OptionallyEnableMlock determines if mlock should be called, and if so enables +// mlock. +func OptionallyEnableMlock() error { + if os.Getenv(PluginMlockEnabled) == "true" { + return mlock.LockMemory() + } + + return nil +} diff --git a/helper/pluginutil/runner.go b/helper/pluginutil/runner.go index 9963704e5464..539c3b448561 100644 --- a/helper/pluginutil/runner.go +++ b/helper/pluginutil/runner.go @@ -3,18 +3,11 @@ package pluginutil import ( "crypto/sha256" "fmt" - "os" "os/exec" "time" plugin "github.com/hashicorp/go-plugin" - "github.com/hashicorp/vault/helper/mlock" -) - -var ( - // PluginUnwrapTokenEnv is the ENV name used to pass the configuration for - // enabling mlock - PluginMlockEnabled = "VAULT_PLUGIN_MLOCK_ENABLED" + "github.com/hashicorp/vault/helper/wrapping" ) // Looker defines the plugin Lookup function that looks into the plugin catalog @@ -27,7 +20,7 @@ type Looker interface { // metadata needed to run a plugin process. This includes looking up Mlock // configuration and wrapping data in a respose wrapped token. type Wrapper interface { - ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (string, error) + ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) MlockDisabled() bool } @@ -97,13 +90,3 @@ func (r *PluginRunner) Run(wrapper Wrapper, pluginMap map[string]plugin.Plugin, return client, nil } - -// OptionallyEnableMlock determines if mlock should be called, and if so enables -// mlock. -func OptionallyEnableMlock() error { - if os.Getenv(PluginMlockEnabled) == "true" { - return mlock.LockMemory() - } - - return nil -} diff --git a/helper/pluginutil/tls.go b/helper/pluginutil/tls.go index d4c0946e4fbf..ee0c54d89d4c 100644 --- a/helper/pluginutil/tls.go +++ b/helper/pluginutil/tls.go @@ -103,12 +103,15 @@ func WrapServerConfig(sys Wrapper, certBytes []byte, key *ecdsa.PrivateKey) (str return "", err } - wrapToken, err := sys.ResponseWrapData(map[string]interface{}{ + wrapInfo, err := sys.ResponseWrapData(map[string]interface{}{ "ServerCert": certBytes, "ServerKey": rawKey, }, time.Second*10, true) + if err != nil { + return "", err + } - return wrapToken, err + return wrapInfo.Token, nil } // VaultPluginTLSProvider is run inside a plugin and retrives the response diff --git a/helper/wrapping/wrapinfo.go b/helper/wrapping/wrapinfo.go new file mode 100644 index 000000000000..a27219b8a919 --- /dev/null +++ b/helper/wrapping/wrapinfo.go @@ -0,0 +1,23 @@ +package wrapping + +import "time" + +type ResponseWrapInfo struct { + // Setting to non-zero specifies that the response should be wrapped. + // Specifies the desired TTL of the wrapping token. + TTL time.Duration `json:"ttl" structs:"ttl" mapstructure:"ttl"` + + // The token containing the wrapped response + Token string `json:"token" structs:"token" mapstructure:"token"` + + // The creation time. This can be used with the TTL to figure out an + // expected expiration. + CreationTime time.Time `json:"creation_time" structs:"creation_time" mapstructure:"cration_time"` + + // If the contained response is the output of a token creation call, the + // created token's accessor will be accessible here + WrappedAccessor string `json:"wrapped_accessor" structs:"wrapped_accessor" mapstructure:"wrapped_accessor"` + + // The format to use. This doesn't get returned, it's only internal. + Format string `json:"format" structs:"format" mapstructure:"format"` +} diff --git a/logical/response.go b/logical/response.go index ee6bfe1e27c2..2a4646a2ca81 100644 --- a/logical/response.go +++ b/logical/response.go @@ -4,8 +4,8 @@ import ( "errors" "fmt" "reflect" - "time" + "github.com/hashicorp/vault/helper/wrapping" "github.com/mitchellh/copystructure" ) @@ -28,26 +28,6 @@ const ( HTTPStatusCode = "http_status_code" ) -type ResponseWrapInfo struct { - // Setting to non-zero specifies that the response should be wrapped. - // Specifies the desired TTL of the wrapping token. - TTL time.Duration `json:"ttl" structs:"ttl" mapstructure:"ttl"` - - // The token containing the wrapped response - Token string `json:"token" structs:"token" mapstructure:"token"` - - // The creation time. This can be used with the TTL to figure out an - // expected expiration. - CreationTime time.Time `json:"creation_time" structs:"creation_time" mapstructure:"cration_time"` - - // If the contained response is the output of a token creation call, the - // created token's accessor will be accessible here - WrappedAccessor string `json:"wrapped_accessor" structs:"wrapped_accessor" mapstructure:"wrapped_accessor"` - - // The format to use. This doesn't get returned, it's only internal. - Format string `json:"format" structs:"format" mapstructure:"format"` -} - // Response is a struct that stores the response of a request. // It is used to abstract the details of the higher level request protocol. type Response struct { @@ -78,7 +58,7 @@ type Response struct { warnings []string `json:"warnings" structs:"warnings" mapstructure:"warnings"` // Information for wrapping the response in a cubbyhole - WrapInfo *ResponseWrapInfo `json:"wrap_info" structs:"wrap_info" mapstructure:"wrap_info"` + WrapInfo *wrapping.ResponseWrapInfo `json:"wrap_info" structs:"wrap_info" mapstructure:"wrap_info"` } func init() { @@ -123,7 +103,7 @@ func init() { if err != nil { return nil, fmt.Errorf("error copying WrapInfo: %v", err) } - ret.WrapInfo = retWrapInfo.(*ResponseWrapInfo) + ret.WrapInfo = retWrapInfo.(*wrapping.ResponseWrapInfo) } return &ret, nil diff --git a/logical/system_view.go b/logical/system_view.go index b6ab14b1fc01..e13b63f28749 100644 --- a/logical/system_view.go +++ b/logical/system_view.go @@ -6,6 +6,7 @@ import ( "github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/helper/pluginutil" + "github.com/hashicorp/vault/helper/wrapping" ) // SystemView exposes system configuration information in a safe way @@ -42,7 +43,7 @@ type SystemView interface { // ResponseWrapData wraps the given data in a cubbyhole and returns the // token used to unwrap. - ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (string, error) + ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) // LookupPlugin looks into the plugin catalog for a plugin with the given // name. Returns a PluginRunner or an error if a plugin can not be found. @@ -87,8 +88,8 @@ func (d StaticSystemView) ReplicationState() consts.ReplicationState { return d.ReplicationStateVal } -func (d StaticSystemView) ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (string, error) { - return "", errors.New("ResponseWrapData is not implemented in StaticSystemView") +func (d StaticSystemView) ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) { + return nil, errors.New("ResponseWrapData is not implemented in StaticSystemView") } func (d StaticSystemView) LookupPlugin(name string) (*pluginutil.PluginRunner, error) { diff --git a/vault/dynamic_system_view.go b/vault/dynamic_system_view.go index eb99f29c62b0..9302bfbc1479 100644 --- a/vault/dynamic_system_view.go +++ b/vault/dynamic_system_view.go @@ -5,6 +5,7 @@ import ( "github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/helper/pluginutil" + "github.com/hashicorp/vault/helper/wrapping" "github.com/hashicorp/vault/logical" ) @@ -91,14 +92,14 @@ func (d dynamicSystemView) ReplicationState() consts.ReplicationState { // ResponseWrapData wraps the given data in a cubbyhole and returns the // token used to unwrap. -func (d dynamicSystemView) ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (string, error) { +func (d dynamicSystemView) ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) { req := &logical.Request{ Operation: logical.CreateOperation, - Path: "sys/init", + Path: "sys/wrapping/wrap", } resp := &logical.Response{ - WrapInfo: &logical.ResponseWrapInfo{ + WrapInfo: &wrapping.ResponseWrapInfo{ TTL: ttl, }, Data: data, @@ -110,10 +111,10 @@ func (d dynamicSystemView) ResponseWrapData(data map[string]interface{}, ttl tim _, err := d.core.wrapInCubbyhole(req, resp) if err != nil { - return "", err + return nil, err } - return resp.WrapInfo.Token, nil + return resp.WrapInfo, nil } // LookupPlugin looks for a plugin with the given name in the plugin catalog. It diff --git a/vault/logical_system.go b/vault/logical_system.go index 843483449ffb..109100090c4f 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -11,6 +11,7 @@ import ( "github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/helper/parseutil" + "github.com/hashicorp/vault/helper/wrapping" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" "github.com/mitchellh/mapstructure" @@ -2075,7 +2076,7 @@ func (b *SystemBackend) handleWrappingRewrap( Data: map[string]interface{}{ "response": response, }, - WrapInfo: &logical.ResponseWrapInfo{ + WrapInfo: &wrapping.ResponseWrapInfo{ TTL: time.Duration(creationTTL), }, }, nil diff --git a/vault/request_handling.go b/vault/request_handling.go index ad37b5aee0b6..1326ef518ac7 100644 --- a/vault/request_handling.go +++ b/vault/request_handling.go @@ -11,6 +11,7 @@ import ( "github.com/hashicorp/vault/helper/jsonutil" "github.com/hashicorp/vault/helper/policyutil" "github.com/hashicorp/vault/helper/strutil" + "github.com/hashicorp/vault/helper/wrapping" "github.com/hashicorp/vault/logical" ) @@ -216,7 +217,7 @@ func (c *Core) handleRequest(req *logical.Request) (retResp *logical.Response, r } if wrapTTL > 0 { - resp.WrapInfo = &logical.ResponseWrapInfo{ + resp.WrapInfo = &wrapping.ResponseWrapInfo{ TTL: wrapTTL, Format: wrapFormat, } @@ -361,7 +362,7 @@ func (c *Core) handleLoginRequest(req *logical.Request) (*logical.Response, *log } if wrapTTL > 0 { - resp.WrapInfo = &logical.ResponseWrapInfo{ + resp.WrapInfo = &wrapping.ResponseWrapInfo{ TTL: wrapTTL, Format: wrapFormat, } From 4c306bd76e23ffc0b552123c2003ad85cc9b8a59 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 24 Apr 2017 12:21:49 -0700 Subject: [PATCH 092/152] Change MlockDisabled to MlockEnabled --- helper/pluginutil/runner.go | 11 ++++------- logical/system_view.go | 11 ++++++----- vault/core.go | 4 ++-- vault/dynamic_system_view.go | 6 +++--- 4 files changed, 15 insertions(+), 17 deletions(-) diff --git a/helper/pluginutil/runner.go b/helper/pluginutil/runner.go index 539c3b448561..6a8df73855d8 100644 --- a/helper/pluginutil/runner.go +++ b/helper/pluginutil/runner.go @@ -21,7 +21,7 @@ type Looker interface { // configuration and wrapping data in a respose wrapped token. type Wrapper interface { ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) - MlockDisabled() bool + MlockEnabled() bool } // LookWrapper defines the functions for both Looker and Wrapper @@ -63,17 +63,14 @@ func (r *PluginRunner) Run(wrapper Wrapper, pluginMap map[string]plugin.Plugin, return nil, err } - mlock := "true" - if wrapper.MlockDisabled() { - mlock = "false" - } - cmd := exec.Command(r.Command, r.Args...) cmd.Env = append(cmd.Env, env...) // Add the response wrap token to the ENV of the plugin cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginUnwrapTokenEnv, wrapToken)) // Add the mlock setting to the ENV of the plugin - cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginMlockEnabled, mlock)) + if wrapper.MlockEnabled() { + cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginMlockEnabled, "true")) + } secureConfig := &plugin.SecureConfig{ Checksum: r.Sha256, diff --git a/logical/system_view.go b/logical/system_view.go index e13b63f28749..175edc0f9a40 100644 --- a/logical/system_view.go +++ b/logical/system_view.go @@ -49,8 +49,9 @@ type SystemView interface { // name. Returns a PluginRunner or an error if a plugin can not be found. LookupPlugin(string) (*pluginutil.PluginRunner, error) - // MlockDisabled returns the configuration setting for DisableMlock. - MlockDisabled() bool + // MlockEnabled returns the configuration setting for Enableing mlock on + // plugins. + MlockEnabled() bool } type StaticSystemView struct { @@ -60,7 +61,7 @@ type StaticSystemView struct { TaintedVal bool CachingDisabledVal bool Primary bool - DisableMlock bool + EnableMlock bool ReplicationStateVal consts.ReplicationState } @@ -96,6 +97,6 @@ func (d StaticSystemView) LookupPlugin(name string) (*pluginutil.PluginRunner, e return nil, errors.New("LookupPlugin is not implemented in StaticSystemView") } -func (d StaticSystemView) MlockDisabled() bool { - return d.DisableMlock +func (d StaticSystemView) MlockEnabled() bool { + return d.EnableMlock } diff --git a/vault/core.go b/vault/core.go index 01ab49f752e3..260ce096e08e 100644 --- a/vault/core.go +++ b/vault/core.go @@ -338,7 +338,7 @@ type Core struct { // pluginCatalog is used to manage plugin configurations pluginCatalog *PluginCatalog - disableMlock bool + enableMlock bool } // CoreConfig is used to parameterize a core @@ -441,7 +441,7 @@ func NewCore(conf *CoreConfig) (*Core, error) { clusterName: conf.ClusterName, clusterListenerShutdownCh: make(chan struct{}), clusterListenerShutdownSuccessCh: make(chan struct{}), - disableMlock: conf.DisableMlock, + enableMlock: !conf.DisableMlock, } // Wrap the physical backend in a cache layer if enabled and not already wrapped diff --git a/vault/dynamic_system_view.go b/vault/dynamic_system_view.go index 9302bfbc1479..edac20140251 100644 --- a/vault/dynamic_system_view.go +++ b/vault/dynamic_system_view.go @@ -123,7 +123,7 @@ func (d dynamicSystemView) LookupPlugin(name string) (*pluginutil.PluginRunner, return d.core.pluginCatalog.Get(name) } -// MlockDisabled returns the configuration setting "DisableMlock". -func (d dynamicSystemView) MlockDisabled() bool { - return d.core.disableMlock +// MlockEnabled returns the configuration setting for enabling mlock on plugins. +func (d dynamicSystemView) MlockEnabled() bool { + return d.core.enableMlock } From 7e3f5e69852fe5feb99c67451b7ddbdc583c8ddd Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 24 Apr 2017 12:47:40 -0700 Subject: [PATCH 093/152] Update root paths test --- vault/logical_system_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vault/logical_system_test.go b/vault/logical_system_test.go index ea940d540dd5..9aae06778eeb 100644 --- a/vault/logical_system_test.go +++ b/vault/logical_system_test.go @@ -32,7 +32,7 @@ func TestSystemBackend_RootPaths(t *testing.T) { "replication/reindex", "rotate", "config/auditing/*", - "plugin-catalog/*", + "plugins/catalog/*", } b := testSystemBackend(t) From 4315e689715aa77b16d93ab00278790e9309ccde Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 24 Apr 2017 13:48:46 -0700 Subject: [PATCH 094/152] Fix test --- vault/logical_system_test.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/vault/logical_system_test.go b/vault/logical_system_test.go index 9aae06778eeb..7bedf7cd6645 100644 --- a/vault/logical_system_test.go +++ b/vault/logical_system_test.go @@ -1123,7 +1123,7 @@ func TestSystemBackend_PluginCatalog_CRUD(t *testing.T) { } c.pluginCatalog.directory = sym - req := logical.TestRequest(t, logical.ListOperation, "plugin-catalog/") + req := logical.TestRequest(t, logical.ListOperation, "plugins/catalog/") resp, err := b.HandleRequest(req) if err != nil { t.Fatalf("err: %v", err) @@ -1133,7 +1133,7 @@ func TestSystemBackend_PluginCatalog_CRUD(t *testing.T) { t.Fatalf("Wrong number of plugins, got %d, expected %d", len(resp.Data["keys"].([]string)), len(builtinplugins.Keys())) } - req = logical.TestRequest(t, logical.ReadOperation, "plugin-catalog/mysql-database-plugin") + req = logical.TestRequest(t, logical.ReadOperation, "plugins/catalog/mysql-database-plugin") resp, err = b.HandleRequest(req) if err != nil { t.Fatalf("err: %v", err) @@ -1164,7 +1164,7 @@ func TestSystemBackend_PluginCatalog_CRUD(t *testing.T) { defer file.Close() command := fmt.Sprintf("%s --test", filepath.Base(file.Name())) - req = logical.TestRequest(t, logical.UpdateOperation, "plugin-catalog/test-plugin") + req = logical.TestRequest(t, logical.UpdateOperation, "plugins/catalog/test-plugin") req.Data["sha_256"] = hex.EncodeToString([]byte{'1'}) req.Data["command"] = command resp, err = b.HandleRequest(req) @@ -1172,7 +1172,7 @@ func TestSystemBackend_PluginCatalog_CRUD(t *testing.T) { t.Fatalf("err: %v", err) } - req = logical.TestRequest(t, logical.ReadOperation, "plugin-catalog/test-plugin") + req = logical.TestRequest(t, logical.ReadOperation, "plugins/catalog/test-plugin") resp, err = b.HandleRequest(req) if err != nil { t.Fatalf("err: %v", err) @@ -1190,13 +1190,13 @@ func TestSystemBackend_PluginCatalog_CRUD(t *testing.T) { } // Delete plugin - req = logical.TestRequest(t, logical.DeleteOperation, "plugin-catalog/test-plugin") + req = logical.TestRequest(t, logical.DeleteOperation, "plugins/catalog/test-plugin") resp, err = b.HandleRequest(req) if err != nil { t.Fatalf("err: %v", err) } - req = logical.TestRequest(t, logical.ReadOperation, "plugin-catalog/test-plugin") + req = logical.TestRequest(t, logical.ReadOperation, "plugins/catalog/test-plugin") resp, err = b.HandleRequest(req) if err == nil { t.Fatalf("expected error, plugin not deleted correctly") From f6b96ccfa2c7ca2a2e73690003c146f069ab0900 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 24 Apr 2017 13:59:12 -0700 Subject: [PATCH 095/152] s/DatabaseType/Database/ --- builtin/logical/database/backend.go | 8 ++++---- builtin/logical/database/dbplugin/client.go | 4 ++-- .../database/dbplugin/databasemiddleware.go | 8 ++++---- builtin/logical/database/dbplugin/plugin.go | 14 +++++++------- builtin/logical/database/dbplugin/server.go | 10 +++++----- builtin/logical/database/path_config_connection.go | 2 +- plugins/database/mssql/mssql.go | 2 +- plugins/helper/database/connutil/connutil.go | 2 +- plugins/helper/database/credsutil/credsutil.go | 2 +- 9 files changed, 26 insertions(+), 26 deletions(-) diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index e57fa19c180d..7d6ffe9c9873 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -41,12 +41,12 @@ func Backend(conf *logical.BackendConfig) *databaseBackend { } b.logger = conf.Logger - b.connections = make(map[string]dbplugin.DatabaseType) + b.connections = make(map[string]dbplugin.Database) return &b } type databaseBackend struct { - connections map[string]dbplugin.DatabaseType + connections map[string]dbplugin.Database logger log.Logger *framework.Backend @@ -62,13 +62,13 @@ func (b *databaseBackend) closeAllDBs() { db.Close() } - b.connections = nil + b.connections = make(map[string]dbplugin.Database) } // This function is used to retrieve a database object either from the cached // connection map or by using the database config in storage. The caller of this // function needs to hold the backend's lock. -func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (dbplugin.DatabaseType, error) { +func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (dbplugin.Database, error) { // if the object already is built and cached, return it db, ok := b.connections[name] if ok { diff --git a/builtin/logical/database/dbplugin/client.go b/builtin/logical/database/dbplugin/client.go index 93db86595a1c..8cfc3aad00a1 100644 --- a/builtin/logical/database/dbplugin/client.go +++ b/builtin/logical/database/dbplugin/client.go @@ -29,7 +29,7 @@ func (dc *DatabasePluginClient) Close() error { // newPluginClient returns a databaseRPCClient with a connection to a running // plugin. The client is wrapped in a DatabasePluginClient object to ensure the // plugin is killed on call of Close(). -func newPluginClient(sys pluginutil.Wrapper, pluginRunner *pluginutil.PluginRunner) (DatabaseType, error) { +func newPluginClient(sys pluginutil.Wrapper, pluginRunner *pluginutil.PluginRunner) (Database, error) { // pluginMap is the map of plugins we can dispense. var pluginMap = map[string]plugin.Plugin{ "database": new(DatabasePlugin), @@ -65,7 +65,7 @@ func newPluginClient(sys pluginutil.Wrapper, pluginRunner *pluginutil.PluginRunn // ---- RPC client domain ---- -// databasePluginRPCClient implements DatabaseType and is used on the client to +// databasePluginRPCClient implements Database and is used on the client to // make RPC calls to a plugin. type databasePluginRPCClient struct { client *rpc.Client diff --git a/builtin/logical/database/dbplugin/databasemiddleware.go b/builtin/logical/database/dbplugin/databasemiddleware.go index e28a8741e43c..9ab35b740fee 100644 --- a/builtin/logical/database/dbplugin/databasemiddleware.go +++ b/builtin/logical/database/dbplugin/databasemiddleware.go @@ -9,10 +9,10 @@ import ( // ---- Tracing Middleware Domain ---- -// databaseTracingMiddleware wraps a implementation of DatabaseType and executes +// databaseTracingMiddleware wraps a implementation of Database and executes // trace logging on function call. type databaseTracingMiddleware struct { - next DatabaseType + next Database logger log.Logger typeStr string @@ -79,10 +79,10 @@ func (mw *databaseTracingMiddleware) Close() (err error) { // ---- Metrics Middleware Domain ---- -// databaseMetricsMiddleware wraps an implementation of DatabaseTypes and on +// databaseMetricsMiddleware wraps an implementation of Databases and on // function call logs metrics about this instance. type databaseMetricsMiddleware struct { - next DatabaseType + next Database typeStr string } diff --git a/builtin/logical/database/dbplugin/plugin.go b/builtin/logical/database/dbplugin/plugin.go index 9a6691fbabad..21812423c1a2 100644 --- a/builtin/logical/database/dbplugin/plugin.go +++ b/builtin/logical/database/dbplugin/plugin.go @@ -10,8 +10,8 @@ import ( log "github.com/mgutz/logxi/v1" ) -// DatabaseType is the interface that all database objects must implement. -type DatabaseType interface { +// Database is the interface that all database objects must implement. +type Database interface { Type() (string, error) CreateUser(statements Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) RenewUser(statements Statements, username string, expiration time.Time) error @@ -31,24 +31,24 @@ type Statements struct { // PluginFactory is used to build plugin database types. It wraps the database // object in a logging and metrics middleware. -func PluginFactory(pluginName string, sys pluginutil.LookWrapper, logger log.Logger) (DatabaseType, error) { +func PluginFactory(pluginName string, sys pluginutil.LookWrapper, logger log.Logger) (Database, error) { // Look for plugin in the plugin catalog pluginRunner, err := sys.LookupPlugin(pluginName) if err != nil { return nil, err } - var db DatabaseType + var db Database if pluginRunner.Builtin { // Plugin is builtin so we can retrieve an instance of the interface - // from the pluginRunner. Then cast it to a DatabaseType. + // from the pluginRunner. Then cast it to a Database. dbRaw, err := pluginRunner.BuiltinFactory() if err != nil { return nil, fmt.Errorf("error getting plugin type: %s", err) } var ok bool - db, ok = dbRaw.(DatabaseType) + db, ok = dbRaw.(Database) if !ok { return nil, fmt.Errorf("unsuported database type: %s", pluginName) } @@ -95,7 +95,7 @@ var handshakeConfig = plugin.HandshakeConfig{ // DatabasePlugin implements go-plugin's Plugin interface. It has methods for // retrieving a server and a client instance of the plugin. type DatabasePlugin struct { - impl DatabaseType + impl Database } func (d DatabasePlugin) Server(*plugin.MuxBroker) (interface{}, error) { diff --git a/builtin/logical/database/dbplugin/server.go b/builtin/logical/database/dbplugin/server.go index 3a3e233946ec..04cc3d7e9041 100644 --- a/builtin/logical/database/dbplugin/server.go +++ b/builtin/logical/database/dbplugin/server.go @@ -8,9 +8,9 @@ import ( ) // NewPluginServer is called from within a plugin and wraps the provided -// DatabaseType implementation in a databasePluginRPCServer object and starts a +// Database implementation in a databasePluginRPCServer object and starts a // RPC server. -func NewPluginServer(db DatabaseType) { +func NewPluginServer(db Database) { dbPlugin := &DatabasePlugin{ impl: db, } @@ -35,10 +35,10 @@ func NewPluginServer(db DatabaseType) { // ---- RPC server domain ---- -// databasePluginRPCServer implements an RPC version of DatabaseType and is run -// inside a plugin. It wraps an underlying implementation of DatabaseType. +// databasePluginRPCServer implements an RPC version of Database and is run +// inside a plugin. It wraps an underlying implementation of Database. type databasePluginRPCServer struct { - impl DatabaseType + impl Database } func (ds *databasePluginRPCServer) Type(_ struct{}, resp *string) error { diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 2a0022b4d869..f154ae1643d2 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -16,7 +16,7 @@ var ( respErrEmptyName = logical.ErrorResponse("Empty name attribute given") ) -// DatabaseConfig is used by the Factory function to configure a DatabaseType +// DatabaseConfig is used by the Factory function to configure a Database // object. type DatabaseConfig struct { PluginName string `json:"plugin_name" structs:"plugin_name" mapstructure:"plugin_name"` diff --git a/plugins/database/mssql/mssql.go b/plugins/database/mssql/mssql.go index b0e0ab6d4198..54f2a9711abc 100644 --- a/plugins/database/mssql/mssql.go +++ b/plugins/database/mssql/mssql.go @@ -15,7 +15,7 @@ import ( const msSQLTypeName = "mssql" -// MSSQL is an implementation of DatabaseType interface +// MSSQL is an implementation of Database interface type MSSQL struct { connutil.ConnectionProducer credsutil.CredentialsProducer diff --git a/plugins/helper/database/connutil/connutil.go b/plugins/helper/database/connutil/connutil.go index 6de3299e3899..c43691c6164d 100644 --- a/plugins/helper/database/connutil/connutil.go +++ b/plugins/helper/database/connutil/connutil.go @@ -9,7 +9,7 @@ var ( errNotInitialized = errors.New("connection has not been initalized") ) -// ConnectionProducer can be used as an embeded interface in the DatabaseType +// ConnectionProducer can be used as an embeded interface in the Database // definition. It implements the methods dealing with individual database // connections and is used in all the builtin database types. type ConnectionProducer interface { diff --git a/plugins/helper/database/credsutil/credsutil.go b/plugins/helper/database/credsutil/credsutil.go index 7f388a0f7678..bc35617ac215 100644 --- a/plugins/helper/database/credsutil/credsutil.go +++ b/plugins/helper/database/credsutil/credsutil.go @@ -2,7 +2,7 @@ package credsutil import "time" -// CredentialsProducer can be used as an embeded interface in the DatabaseType +// CredentialsProducer can be used as an embeded interface in the Database // definition. It implements the methods for generating user information for a // particular database type and is used in all the builtin database types. type CredentialsProducer interface { From 194695f1fa1cbbd65d77e74301cb49ba793a4861 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 24 Apr 2017 14:03:48 -0700 Subject: [PATCH 096/152] Don't uppercase ErrorResponses --- builtin/logical/database/path_config_connection.go | 8 ++++---- builtin/logical/database/path_role_create.go | 2 +- builtin/logical/database/path_roles.go | 8 ++++---- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index f154ae1643d2..965364dc5252 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -12,8 +12,8 @@ import ( ) var ( - respErrEmptyPluginName = logical.ErrorResponse("Empty plugin name") - respErrEmptyName = logical.ErrorResponse("Empty name attribute given") + respErrEmptyPluginName = logical.ErrorResponse("empty plugin name") + respErrEmptyName = logical.ErrorResponse("empty name attribute given") ) // DatabaseConfig is used by the Factory function to configure a Database @@ -199,13 +199,13 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { db, err := dbplugin.PluginFactory(config.PluginName, b.System(), b.logger) if err != nil { - return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil + return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil } err = db.Initialize(config.ConnectionDetails, verifyConnection) if err != nil { db.Close() - return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil + return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil } // Grab the mutex lock diff --git a/builtin/logical/database/path_role_create.go b/builtin/logical/database/path_role_create.go index 631802dff643..a8da211f2ee7 100644 --- a/builtin/logical/database/path_role_create.go +++ b/builtin/logical/database/path_role_create.go @@ -38,7 +38,7 @@ func (b *databaseBackend) pathRoleCreateRead() framework.OperationFunc { return nil, err } if role == nil { - return logical.ErrorResponse(fmt.Sprintf("Unknown role: %s", name)), nil + return logical.ErrorResponse(fmt.Sprintf("unknown role: %s", name)), nil } dbConfig, err := b.DatabaseConfig(req.Storage, role.DBName) diff --git a/builtin/logical/database/path_roles.go b/builtin/logical/database/path_roles.go index b3393b1ba884..e85b123dceff 100644 --- a/builtin/logical/database/path_roles.go +++ b/builtin/logical/database/path_roles.go @@ -136,12 +136,12 @@ func (b *databaseBackend) pathRoleCreate() framework.OperationFunc { return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { name := data.Get("name").(string) if name == "" { - return logical.ErrorResponse("Empty role name attribute given"), nil + return logical.ErrorResponse("empty role name attribute given"), nil } dbName := data.Get("db_name").(string) if dbName == "" { - return logical.ErrorResponse("Empty database name attribute given"), nil + return logical.ErrorResponse("empty database name attribute given"), nil } // Get statements @@ -157,12 +157,12 @@ func (b *databaseBackend) pathRoleCreate() framework.OperationFunc { defaultTTL, err := time.ParseDuration(defaultTTLRaw) if err != nil { return logical.ErrorResponse(fmt.Sprintf( - "Invalid default_ttl: %s", err)), nil + "invalid default_ttl: %s", err)), nil } maxTTL, err := time.ParseDuration(maxTTLRaw) if err != nil { return logical.ErrorResponse(fmt.Sprintf( - "Invalid max_ttl: %s", err)), nil + "invalid max_ttl: %s", err)), nil } statements := dbplugin.Statements{ From 1971d65ea3e99f9c34bd56e351a7cf2964a72c8d Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 24 Apr 2017 16:20:20 -0700 Subject: [PATCH 097/152] Only run Abs on the plugin directory if it's set --- vault/core.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vault/core.go b/vault/core.go index 260ce096e08e..f3ed06d3896d 100644 --- a/vault/core.go +++ b/vault/core.go @@ -466,9 +466,11 @@ func NewCore(conf *CoreConfig) (*Core, error) { } var err error - c.pluginDirectory, err = filepath.Abs(conf.PluginDirectory) - if err != nil { - return nil, fmt.Errorf("core setup failed: %v", err) + if conf.PluginDirectory != "" { + c.pluginDirectory, err = filepath.Abs(conf.PluginDirectory) + if err != nil { + return nil, fmt.Errorf("core setup failed, could not verify plugin directory: %v", err) + } } // Construct a new AES-GCM barrier From 57f78c4cd5556d1d0311feff3b27a80ad4c23dbb Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 24 Apr 2017 18:31:27 -0700 Subject: [PATCH 098/152] return a 404 when no plugin is found --- vault/dynamic_system_view.go | 11 ++++++++++- vault/logical_system.go | 3 +++ vault/plugin_catalog.go | 2 +- 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/vault/dynamic_system_view.go b/vault/dynamic_system_view.go index edac20140251..3844b46bfa65 100644 --- a/vault/dynamic_system_view.go +++ b/vault/dynamic_system_view.go @@ -1,6 +1,7 @@ package vault import ( + "fmt" "time" "github.com/hashicorp/vault/helper/consts" @@ -120,7 +121,15 @@ func (d dynamicSystemView) ResponseWrapData(data map[string]interface{}, ttl tim // LookupPlugin looks for a plugin with the given name in the plugin catalog. It // returns a PluginRunner or an error if no plugin was found. func (d dynamicSystemView) LookupPlugin(name string) (*pluginutil.PluginRunner, error) { - return d.core.pluginCatalog.Get(name) + r, err := d.core.pluginCatalog.Get(name) + if err != nil { + return nil, err + } + if r == nil { + return nil, fmt.Errorf("no plugin found with name: %s", name) + } + + return r, nil } // MlockEnabled returns the configuration setting for enabling mlock on plugins. diff --git a/vault/logical_system.go b/vault/logical_system.go index 109100090c4f..4dd66f8147e6 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -815,6 +815,9 @@ func (b *SystemBackend) handlePluginCatalogRead(req *logical.Request, d *framewo if err != nil { return nil, err } + if plugin == nil { + return nil, nil + } return &logical.Response{ Data: map[string]interface{}{ diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go index 5d88873b3189..095d81b1e4f5 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugin_catalog.go @@ -71,7 +71,7 @@ func (c *PluginCatalog) Get(name string) (*pluginutil.PluginRunner, error) { }, nil } - return nil, fmt.Errorf("no plugin found with name: %s", name) + return nil, nil } // Set registers a new external plugin with the catalog, or updates an existing From 630962bc9677c1a0c218df527d9994c6ff1a5e94 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 24 Apr 2017 21:24:19 -0700 Subject: [PATCH 099/152] Update test to reflect the correct read response --- vault/logical_system_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vault/logical_system_test.go b/vault/logical_system_test.go index 7bedf7cd6645..aa2ce449a1ca 100644 --- a/vault/logical_system_test.go +++ b/vault/logical_system_test.go @@ -1198,7 +1198,7 @@ func TestSystemBackend_PluginCatalog_CRUD(t *testing.T) { req = logical.TestRequest(t, logical.ReadOperation, "plugins/catalog/test-plugin") resp, err = b.HandleRequest(req) - if err == nil { - t.Fatalf("expected error, plugin not deleted correctly") + if resp != nil || err != nil { + t.Fatalf("expected nil response, plugin not deleted correctly got resp: %v, err: %v", resp, err) } } From 6741811407af903cab8aabb38943c84eb8b7d619 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 25 Apr 2017 10:24:19 -0700 Subject: [PATCH 100/152] Update logging to new structure --- .../database/dbplugin/databasemiddleware.go | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/builtin/logical/database/dbplugin/databasemiddleware.go b/builtin/logical/database/dbplugin/databasemiddleware.go index 9ab35b740fee..13591e51628d 100644 --- a/builtin/logical/database/dbplugin/databasemiddleware.go +++ b/builtin/logical/database/dbplugin/databasemiddleware.go @@ -25,10 +25,10 @@ func (mw *databaseTracingMiddleware) Type() (string, error) { func (mw *databaseTracingMiddleware) CreateUser(statements Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) { if mw.logger.IsTrace() { defer func(then time.Time) { - mw.logger.Trace("database/CreateUser: finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) + mw.logger.Trace("database", "operation", "CreateUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) }(time.Now()) - mw.logger.Trace("database/CreateUser: starting", "type", mw.typeStr) + mw.logger.Trace("database", "operation", "CreateUser", "status", "started", "type", mw.typeStr) } return mw.next.CreateUser(statements, usernamePrefix, expiration) } @@ -36,10 +36,10 @@ func (mw *databaseTracingMiddleware) CreateUser(statements Statements, usernameP func (mw *databaseTracingMiddleware) RenewUser(statements Statements, username string, expiration time.Time) (err error) { if mw.logger.IsTrace() { defer func(then time.Time) { - mw.logger.Trace("database/RenewUser: finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) + mw.logger.Trace("database", "operation", "RenewUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) }(time.Now()) - mw.logger.Trace("database/RenewUser: starting", "type", mw.typeStr) + mw.logger.Trace("database", "operation", "RenewUser", "status", "started", mw.typeStr) } return mw.next.RenewUser(statements, username, expiration) } @@ -47,10 +47,10 @@ func (mw *databaseTracingMiddleware) RenewUser(statements Statements, username s func (mw *databaseTracingMiddleware) RevokeUser(statements Statements, username string) (err error) { if mw.logger.IsTrace() { defer func(then time.Time) { - mw.logger.Trace("database/RevokeUser: finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) + mw.logger.Trace("database", "operation", "RevokeUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) }(time.Now()) - mw.logger.Trace("database/RevokeUser: starting", "type", mw.typeStr) + mw.logger.Trace("database", "operation", "RevokeUser", "status", "started", "type", mw.typeStr) } return mw.next.RevokeUser(statements, username) } @@ -58,10 +58,10 @@ func (mw *databaseTracingMiddleware) RevokeUser(statements Statements, username func (mw *databaseTracingMiddleware) Initialize(conf map[string]interface{}, verifyConnection bool) (err error) { if mw.logger.IsTrace() { defer func(then time.Time) { - mw.logger.Trace("database/Initialize: finished", "type", mw.typeStr, "verify", verifyConnection, "err", err, "took", time.Since(then)) + mw.logger.Trace("database", "operation", "Initialize", "status", "finished", "type", mw.typeStr, "verify", verifyConnection, "err", err, "took", time.Since(then)) }(time.Now()) - mw.logger.Trace("database/Initialize: starting", "type", mw.typeStr) + mw.logger.Trace("database", "operation", "Initialize", "status", "started", "type", mw.typeStr) } return mw.next.Initialize(conf, verifyConnection) } @@ -69,10 +69,10 @@ func (mw *databaseTracingMiddleware) Initialize(conf map[string]interface{}, ver func (mw *databaseTracingMiddleware) Close() (err error) { if mw.logger.IsTrace() { defer func(then time.Time) { - mw.logger.Trace("database/Close: finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) + mw.logger.Trace("database", "operation", "Close", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) }(time.Now()) - mw.logger.Trace("database/Close: starting", "type", mw.typeStr) + mw.logger.Trace("database", "operation", "Close", "status", "started", "type", mw.typeStr) } return mw.next.Close() } From 22612adefcee547311e5789e60ba5723e7eb24d2 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 25 Apr 2017 10:26:23 -0700 Subject: [PATCH 101/152] Use TypeCommaStringSlice for allowed_roles --- builtin/logical/database/path_config_connection.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 965364dc5252..557d3f3cb9c1 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -98,10 +98,10 @@ func pathConfigurePluginConnection(b *databaseBackend) *framework.Path { }, "allowed_roles": &framework.FieldSchema{ - Type: framework.TypeString, - Description: `Comma separated list of the role names allowed to - get creds from this database connection. If not set all roles - are allowed.`, + Type: framework.TypeCommaStringSlice, + Description: `Comma separated string or array of the role names + allowed to get creds from this database connection. If not set + all roles are allowed.`, }, }, From 58b0bbd47793714a1c720a7aab5341d3c9a02dfd Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 25 Apr 2017 10:39:17 -0700 Subject: [PATCH 102/152] Rename path_role_create to path_creds_create --- builtin/logical/database/backend.go | 2 +- .../{path_role_create.go => path_creds_create.go} | 15 +++++++-------- 2 files changed, 8 insertions(+), 9 deletions(-) rename builtin/logical/database/{path_role_create.go => path_creds_create.go} (84%) diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index 7d6ffe9c9873..e8cf98ebbdfc 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -27,7 +27,7 @@ func Backend(conf *logical.BackendConfig) *databaseBackend { pathConfigurePluginConnection(&b), pathListRoles(&b), pathRoles(&b), - pathRoleCreate(&b), + pathCredsCreate(&b), pathResetConnection(&b), }, diff --git a/builtin/logical/database/path_role_create.go b/builtin/logical/database/path_creds_create.go similarity index 84% rename from builtin/logical/database/path_role_create.go rename to builtin/logical/database/path_creds_create.go index a8da211f2ee7..341c61d67c09 100644 --- a/builtin/logical/database/path_role_create.go +++ b/builtin/logical/database/path_creds_create.go @@ -9,7 +9,7 @@ import ( "github.com/hashicorp/vault/logical/framework" ) -func pathRoleCreate(b *databaseBackend) *framework.Path { +func pathCredsCreate(b *databaseBackend) *framework.Path { return &framework.Path{ Pattern: "creds/" + framework.GenericNameRegex("name"), Fields: map[string]*framework.FieldSchema{ @@ -20,15 +20,15 @@ func pathRoleCreate(b *databaseBackend) *framework.Path { }, Callbacks: map[logical.Operation]framework.OperationFunc{ - logical.ReadOperation: b.pathRoleCreateRead(), + logical.ReadOperation: b.pathCredsCreateRead(), }, - HelpSynopsis: pathRoleCreateReadHelpSyn, - HelpDescription: pathRoleCreateReadHelpDesc, + HelpSynopsis: pathCredsCreateReadHelpSyn, + HelpDescription: pathCredsCreateReadHelpDesc, } } -func (b *databaseBackend) pathRoleCreateRead() framework.OperationFunc { +func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc { return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { name := data.Get("name").(string) @@ -58,7 +58,6 @@ func (b *databaseBackend) pathRoleCreateRead() framework.OperationFunc { // Get the Database object db, err := b.getOrCreateDBObj(req.Storage, role.DBName) if err != nil { - // TODO: return a resp error instead? return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) } @@ -82,11 +81,11 @@ func (b *databaseBackend) pathRoleCreateRead() framework.OperationFunc { } } -const pathRoleCreateReadHelpSyn = ` +const pathCredsCreateReadHelpSyn = ` Request database credentials for a certain role. ` -const pathRoleCreateReadHelpDesc = ` +const pathCredsCreateReadHelpDesc = ` This path reads database credentials for a certain role. The database credentials will be generated on demand and will be automatically revoked when the lease is up. From e18757628c70fa2a435ed014dd33e44fcb55b7b3 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 25 Apr 2017 11:11:10 -0700 Subject: [PATCH 103/152] Update the connection details data and fix allowedRoles --- builtin/logical/database/path_config_connection.go | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 557d3f3cb9c1..7c175848f2c6 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -6,7 +6,6 @@ import ( "github.com/fatih/structs" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" - "github.com/hashicorp/vault/helper/strutil" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -187,9 +186,14 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { verifyConnection := data.Get("verify_connection").(bool) - // Pasrse and dedupe allowed roles from a comma separated string. - allowedRolesRaw := data.Get("allowed_roles").(string) - allowedRoles := strutil.ParseDedupAndSortStrings(allowedRolesRaw, ",") + allowedRoles := data.Get("allowed_roles").([]string) + + // Remove these entries from the data before we store it keyed under + // ConnectionDetails. + delete(data.Raw, "name") + delete(data.Raw, "plugin_name") + delete(data.Raw, "allowed_roles") + delete(data.Raw, "verify_connection") config := &DatabaseConfig{ ConnectionDetails: data.Raw, From 6131bdd3b9ed5f7a5af48907e777eea883dd3fcd Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 25 Apr 2017 11:48:24 -0700 Subject: [PATCH 104/152] Default deny when allowed roles is empty --- builtin/logical/database/backend_test.go | 84 +++++++++++++++++-- .../database/path_config_connection.go | 4 +- builtin/logical/database/path_creds_create.go | 2 +- 3 files changed, 80 insertions(+), 10 deletions(-) diff --git a/builtin/logical/database/backend_test.go b/builtin/logical/database/backend_test.go index 2ece767fcd3e..08317cbdc42f 100644 --- a/builtin/logical/database/backend_test.go +++ b/builtin/logical/database/backend_test.go @@ -113,6 +113,7 @@ func TestBackend_config_connection(t *testing.T) { "connection_url": "sample_connection_url", "plugin_name": "postgresql-database-plugin", "verify_connection": false, + "allowed_roles": []string{"*"}, } configReq := &logical.Request{ @@ -127,9 +128,11 @@ func TestBackend_config_connection(t *testing.T) { } expected := map[string]interface{}{ - "plugin_name": "postgresql-database-plugin", - "connection_details": configData, - "allowed_roles": []string{}, + "plugin_name": "postgresql-database-plugin", + "connection_details": map[string]interface{}{ + "connection_url": "sample_connection_url", + }, + "allowed_roles": []string{"*"}, } configReq.Operation = logical.ReadOperation resp, err = b.HandleRequest(configReq) @@ -164,6 +167,7 @@ func TestBackend_basic(t *testing.T) { data := map[string]interface{}{ "connection_url": connURL, "plugin_name": "postgresql-database-plugin", + "allowed_roles": []string{"plugin-role-test"}, } req := &logical.Request{ Operation: logical.UpdateOperation, @@ -290,6 +294,7 @@ func TestBackend_connectionCrud(t *testing.T) { data = map[string]interface{}{ "connection_url": connURL, "plugin_name": "postgresql-database-plugin", + "allowed_roles": []string{"plugin-role-test"}, } req = &logical.Request{ Operation: logical.UpdateOperation, @@ -304,9 +309,11 @@ func TestBackend_connectionCrud(t *testing.T) { // Read connection expected := map[string]interface{}{ - "plugin_name": "postgresql-database-plugin", - "connection_details": data, - "allowed_roles": []string{}, + "plugin_name": "postgresql-database-plugin", + "connection_details": map[string]interface{}{ + "connection_url": connURL, + }, + "allowed_roles": []string{"plugin-role-test"}, } req.Operation = logical.ReadOperation resp, err = b.HandleRequest(req) @@ -506,7 +513,6 @@ func TestBackend_allowedRoles(t *testing.T) { data := map[string]interface{}{ "connection_url": connURL, "plugin_name": "postgresql-database-plugin", - "allowed_roles": "allow, allowed", } req := &logical.Request{ Operation: logical.UpdateOperation, @@ -567,6 +573,70 @@ func TestBackend_allowedRoles(t *testing.T) { t.Fatalf("expected error to be:%s got:%#v\n", logical.ErrPermissionDenied, err) } + // update connection with * allowed roles connection + data = map[string]interface{}{ + "connection_url": connURL, + "plugin_name": "postgresql-database-plugin", + "allowed_roles": "*", + } + req = &logical.Request{ + Operation: logical.UpdateOperation, + Path: "config/plugin-test", + Storage: config.StorageView, + Data: data, + } + resp, err = b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Get creds, should work. + data = map[string]interface{}{} + req = &logical.Request{ + Operation: logical.ReadOperation, + Path: "creds/allowed", + Storage: config.StorageView, + Data: data, + } + credsResp, err = b.HandleRequest(req) + if err != nil || (credsResp != nil && credsResp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, credsResp) + } + + if !testCredsExist(t, credsResp, connURL) { + t.Fatalf("Creds should exist") + } + + // update connection with allowed roles + data = map[string]interface{}{ + "connection_url": connURL, + "plugin_name": "postgresql-database-plugin", + "allowed_roles": "allow, allowed", + } + req = &logical.Request{ + Operation: logical.UpdateOperation, + Path: "config/plugin-test", + Storage: config.StorageView, + Data: data, + } + resp, err = b.HandleRequest(req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Get creds from denied role, should fail + data = map[string]interface{}{} + req = &logical.Request{ + Operation: logical.ReadOperation, + Path: "creds/denied", + Storage: config.StorageView, + Data: data, + } + credsResp, err = b.HandleRequest(req) + if err != logical.ErrPermissionDenied { + t.Fatalf("expected error to be:%s got:%#v\n", logical.ErrPermissionDenied, err) + } + // Get creds from allowed role, should work. data = map[string]interface{}{} req = &logical.Request{ diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 7c175848f2c6..f52cfec59cc6 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -99,8 +99,8 @@ func pathConfigurePluginConnection(b *databaseBackend) *framework.Path { "allowed_roles": &framework.FieldSchema{ Type: framework.TypeCommaStringSlice, Description: `Comma separated string or array of the role names - allowed to get creds from this database connection. If not set - all roles are allowed.`, + allowed to get creds from this database connection. If empty no + roles are allowed. If "*" all roles are allowed.`, }, }, diff --git a/builtin/logical/database/path_creds_create.go b/builtin/logical/database/path_creds_create.go index 341c61d67c09..9bbaceb54b16 100644 --- a/builtin/logical/database/path_creds_create.go +++ b/builtin/logical/database/path_creds_create.go @@ -48,7 +48,7 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc { // If role name isn't in the database's allowed roles, send back a // permission denied. - if len(dbConfig.AllowedRoles) > 0 && !strutil.StrListContains(dbConfig.AllowedRoles, name) { + if !strutil.StrListContains(dbConfig.AllowedRoles, "*") && !strutil.StrListContains(dbConfig.AllowedRoles, name) { return nil, logical.ErrPermissionDenied } From 37aacba0dac960523e2244e7411217d06a8d7dd4 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 26 Apr 2017 10:02:37 -0700 Subject: [PATCH 105/152] Change ttl types to TypeDurationSecond --- builtin/logical/database/path_roles.go | 26 ++++++++------------------ 1 file changed, 8 insertions(+), 18 deletions(-) diff --git a/builtin/logical/database/path_roles.go b/builtin/logical/database/path_roles.go index e85b123dceff..c81261804641 100644 --- a/builtin/logical/database/path_roles.go +++ b/builtin/logical/database/path_roles.go @@ -1,7 +1,6 @@ package database import ( - "fmt" "time" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" @@ -65,12 +64,12 @@ func pathRoles(b *databaseBackend) *framework.Path { }, "default_ttl": { - Type: framework.TypeString, + Type: framework.TypeDurationSecond, Description: "Default ttl for role.", }, "max_ttl": { - Type: framework.TypeString, + Type: framework.TypeDurationSecond, Description: "Maximum time a credential is valid for", }, }, @@ -114,8 +113,8 @@ func (b *databaseBackend) pathRoleRead() framework.OperationFunc { "revocation_statements": role.Statements.RevocationStatements, "rollback_statements": role.Statements.RollbackStatements, "renew_statements": role.Statements.RenewStatements, - "default_ttl": role.DefaultTTL.String(), - "max_ttl": role.MaxTTL.String(), + "default_ttl": role.DefaultTTL.Seconds(), + "max_ttl": role.MaxTTL.Seconds(), }, }, nil } @@ -151,19 +150,10 @@ func (b *databaseBackend) pathRoleCreate() framework.OperationFunc { renewStmts := data.Get("renew_statements").(string) // Get TTLs - defaultTTLRaw := data.Get("default_ttl").(string) - maxTTLRaw := data.Get("max_ttl").(string) - - defaultTTL, err := time.ParseDuration(defaultTTLRaw) - if err != nil { - return logical.ErrorResponse(fmt.Sprintf( - "invalid default_ttl: %s", err)), nil - } - maxTTL, err := time.ParseDuration(maxTTLRaw) - if err != nil { - return logical.ErrorResponse(fmt.Sprintf( - "invalid max_ttl: %s", err)), nil - } + defaultTTLRaw := data.Get("default_ttl").(int) + maxTTLRaw := data.Get("max_ttl").(int) + defaultTTL := time.Duration(defaultTTLRaw) * time.Second + maxTTL := time.Duration(maxTTLRaw) * time.Second statements := dbplugin.Statements{ CreationStatements: creationStmts, From d8dbfc6a0cc073806886eb0bd5dd002aac057b3b Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 26 Apr 2017 10:29:16 -0700 Subject: [PATCH 106/152] Update the error messages for renew and revoke --- builtin/logical/database/secret_creds.go | 24 ++++-------------------- 1 file changed, 4 insertions(+), 20 deletions(-) diff --git a/builtin/logical/database/secret_creds.go b/builtin/logical/database/secret_creds.go index ffc59cf3fec3..2704eb287c04 100644 --- a/builtin/logical/database/secret_creds.go +++ b/builtin/logical/database/secret_creds.go @@ -38,7 +38,7 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc { return nil, err } if role == nil { - return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"]) + return nil, fmt.Errorf("error during renew: could not find role with name %s", req.Secret.InternalData["role"]) } f := framework.LeaseExtend(role.DefaultTTL, role.MaxTTL, b.System()) @@ -54,7 +54,7 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc { // Get our connection db, err := b.getOrCreateDBObj(req.Storage, role.DBName) if err != nil { - return nil, fmt.Errorf("could not find connection with name %s, got err: %s", role.DBName, err) + return nil, fmt.Errorf("error during renew: %s", err) } // Make sure we increase the VALID UNTIL endpoint for this user. @@ -90,25 +90,9 @@ func (b *databaseBackend) secretCredsRevoke() framework.OperationFunc { return nil, err } if role == nil { - return nil, fmt.Errorf("could not find role with name: %s", req.Secret.InternalData["role"]) + return nil, fmt.Errorf("error during revoke: could not find role with name %s", req.Secret.InternalData["role"]) } - /* TODO: think about how to handle this case. - if !ok { - role, err := b.Role(req.Storage, roleNameRaw.(string)) - if err != nil { - return nil, err - } - if role == nil { - if resp == nil { - resp = &logical.Response{} - } - resp.AddWarning(fmt.Sprintf("Role %q cannot be found. Using default revocation SQL.", roleNameRaw.(string))) - } else { - revocationSQL = role.RevocationStatement - } - }*/ - // Grab the read lock b.Lock() defer b.Unlock() @@ -116,7 +100,7 @@ func (b *databaseBackend) secretCredsRevoke() framework.OperationFunc { // Get our connection db, err := b.getOrCreateDBObj(req.Storage, role.DBName) if err != nil { - return nil, fmt.Errorf("could not find database with name: %s, got error: %s", role.DBName, err) + return nil, fmt.Errorf("error during revoke: %s", err) } err = db.RevokeUser(role.Statements, username) From dc9740d97ac31f5ca424066665ab7d2403a762ee Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 26 Apr 2017 10:34:45 -0700 Subject: [PATCH 107/152] Add mssql builtin plugin type --- helper/builtinplugins/builtin.go | 2 ++ plugins/database/mssql/mssql.go | 11 +++++++---- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/helper/builtinplugins/builtin.go b/helper/builtinplugins/builtin.go index b61a51710d22..c20a92603dae 100644 --- a/helper/builtinplugins/builtin.go +++ b/helper/builtinplugins/builtin.go @@ -1,6 +1,7 @@ package builtinplugins import ( + "github.com/hashicorp/vault/plugins/database/mssql" "github.com/hashicorp/vault/plugins/database/mysql" "github.com/hashicorp/vault/plugins/database/postgresql" ) @@ -10,6 +11,7 @@ type BuiltinFactory func() (interface{}, error) var plugins map[string]BuiltinFactory = map[string]BuiltinFactory{ "mysql-database-plugin": mysql.New, "postgresql-database-plugin": postgresql.New, + "mssql-database-plugin": mssql.New, } func Get(name string) (BuiltinFactory, bool) { diff --git a/plugins/database/mssql/mssql.go b/plugins/database/mssql/mssql.go index 54f2a9711abc..48da8ff08b1a 100644 --- a/plugins/database/mssql/mssql.go +++ b/plugins/database/mssql/mssql.go @@ -21,7 +21,7 @@ type MSSQL struct { credsutil.CredentialsProducer } -func New() *MSSQL { +func New() (interface{}, error) { connProducer := &connutil.SQLConnectionProducer{} connProducer.Type = msSQLTypeName @@ -35,14 +35,17 @@ func New() *MSSQL { CredentialsProducer: credsProducer, } - return dbType + return dbType, nil } // Run instantiates a MSSQL object, and runs the RPC server for the plugin func Run() error { - dbType := New() + dbType, err := New() + if err != nil { + return err + } - dbplugin.NewPluginServer(dbType) + dbplugin.NewPluginServer(dbType.(*MSSQL)) return nil } From cb13786f0ab639262b71f1c98b6a5d1a2cdff42d Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 26 Apr 2017 10:52:10 -0700 Subject: [PATCH 108/152] Fix MSSQL test --- plugins/database/mssql/mssql_test.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/plugins/database/mssql/mssql_test.go b/plugins/database/mssql/mssql_test.go index 512033bd76ee..0dc18cb3e2a3 100644 --- a/plugins/database/mssql/mssql_test.go +++ b/plugins/database/mssql/mssql_test.go @@ -27,7 +27,8 @@ func TestMSSQL_Initialize(t *testing.T) { "connection_url": connURL, } - db := New() + dbRaw, _ := New() + db := dbRaw.(*MSSQL) err := db.Initialize(connectionDetails, true) if err != nil { @@ -55,7 +56,8 @@ func TestMSSQL_CreateUser(t *testing.T) { "connection_url": connURL, } - db := New() + dbRaw, _ := New() + db := dbRaw.(*MSSQL) err := db.Initialize(connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) @@ -91,7 +93,8 @@ func TestMSSQL_RevokeUser(t *testing.T) { "connection_url": connURL, } - db := New() + dbRaw, _ := New() + db := dbRaw.(*MSSQL) err := db.Initialize(connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) From 6b050470fdee424a8934383cd6905e484128e226 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 26 Apr 2017 15:23:14 -0700 Subject: [PATCH 109/152] Update to a RWMutex --- builtin/logical/database/backend.go | 20 +++++---- .../database/path_config_connection.go | 2 +- builtin/logical/database/path_creds_create.go | 20 ++++++--- builtin/logical/database/secret_creds.go | 42 ++++++++++++++----- 4 files changed, 58 insertions(+), 26 deletions(-) diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index e8cf98ebbdfc..2aff47375d8e 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -50,7 +50,7 @@ type databaseBackend struct { logger log.Logger *framework.Backend - sync.Mutex + sync.RWMutex } // resetAllDBs closes all connections from all database types @@ -66,21 +66,23 @@ func (b *databaseBackend) closeAllDBs() { } // This function is used to retrieve a database object either from the cached -// connection map or by using the database config in storage. The caller of this -// function needs to hold the backend's lock. -func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (dbplugin.Database, error) { - // if the object already is built and cached, return it +// connection map. The caller of this function needs to hold the backend's read +// lock. +func (b *databaseBackend) getDBObj(name string) (dbplugin.Database, bool) { db, ok := b.connections[name] - if ok { - return db, nil - } + return db, ok +} +// This function creates a new db object from the stored configuration and +// caches it in the connections map. The caller of this function needs to hold +// the backend's write lock +func (b *databaseBackend) createDBObj(s logical.Storage, name string) (dbplugin.Database, error) { config, err := b.DatabaseConfig(s, name) if err != nil { return nil, err } - db, err = dbplugin.PluginFactory(config.PluginName, b.System(), b.logger) + db, err := dbplugin.PluginFactory(config.PluginName, b.System(), b.logger) if err != nil { return nil, err } diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index f52cfec59cc6..39eb3d000819 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -62,7 +62,7 @@ func (b *databaseBackend) pathConnectionReset() framework.OperationFunc { b.clearConnection(name) // Execute plugin again, we don't need the object so throw away. - _, err := b.getOrCreateDBObj(req.Storage, name) + _, err := b.createDBObj(req.Storage, name) if err != nil { return nil, err } diff --git a/builtin/logical/database/path_creds_create.go b/builtin/logical/database/path_creds_create.go index 9bbaceb54b16..60f0c5e3ed1c 100644 --- a/builtin/logical/database/path_creds_create.go +++ b/builtin/logical/database/path_creds_create.go @@ -52,13 +52,23 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc { return nil, logical.ErrPermissionDenied } - b.Lock() - defer b.Unlock() + b.RLock() // Get the Database object - db, err := b.getOrCreateDBObj(req.Storage, role.DBName) - if err != nil { - return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) + db, ok := b.getDBObj(role.DBName) + if !ok { + // Upgrade lock + b.RUnlock() + b.Lock() + defer b.Unlock() + + // Create a new DB object + db, err = b.createDBObj(req.Storage, role.DBName) + if err != nil { + return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) + } + } else { + defer b.RUnlock() } expiration := time.Now().Add(role.DefaultTTL) diff --git a/builtin/logical/database/secret_creds.go b/builtin/logical/database/secret_creds.go index 2704eb287c04..690b41565e25 100644 --- a/builtin/logical/database/secret_creds.go +++ b/builtin/logical/database/secret_creds.go @@ -48,13 +48,23 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc { } // Grab the read lock - b.Lock() - defer b.Unlock() + b.RLock() - // Get our connection - db, err := b.getOrCreateDBObj(req.Storage, role.DBName) - if err != nil { - return nil, fmt.Errorf("error during renew: %s", err) + // Get the Database object + db, ok := b.getDBObj(role.DBName) + if !ok { + // Upgrade lock + b.RUnlock() + b.Lock() + defer b.Unlock() + + // Create a new DB object + db, err = b.createDBObj(req.Storage, role.DBName) + if err != nil { + return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) + } + } else { + defer b.RUnlock() } // Make sure we increase the VALID UNTIL endpoint for this user. @@ -94,13 +104,23 @@ func (b *databaseBackend) secretCredsRevoke() framework.OperationFunc { } // Grab the read lock - b.Lock() - defer b.Unlock() + b.RLock() // Get our connection - db, err := b.getOrCreateDBObj(req.Storage, role.DBName) - if err != nil { - return nil, fmt.Errorf("error during revoke: %s", err) + db, ok := b.getDBObj(role.DBName) + if !ok { + // Upgrade lock + b.RUnlock() + b.Lock() + defer b.Unlock() + + // Create a new DB object + db, err = b.createDBObj(req.Storage, role.DBName) + if err != nil { + return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) + } + } else { + defer b.RUnlock() } err = db.RevokeUser(role.Statements, username) From f92d6868a0a5dc9c02eb6446e4cca579d8a8512a Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 26 Apr 2017 15:55:34 -0700 Subject: [PATCH 110/152] Add an error check to reset a plugin if it is closed --- builtin/logical/database/backend.go | 10 ++++++++++ .../logical/database/path_config_connection.go | 12 ++---------- builtin/logical/database/path_creds_create.go | 10 +++++++--- builtin/logical/database/secret_creds.go | 18 ++++++++++++------ 4 files changed, 31 insertions(+), 19 deletions(-) diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index 2aff47375d8e..4bac4b0c36e8 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -2,6 +2,7 @@ package database import ( "fmt" + "net/rpc" "strings" "sync" @@ -152,6 +153,15 @@ func (b *databaseBackend) clearConnection(name string) { } } +func (b *databaseBackend) closeIfShutdown(name string, err error) { + // Plugin has shutdown, close it so next call can reconnect. + if err == rpc.ErrShutdown { + b.Lock() + b.clearConnection(name) + b.Unlock() + } +} + const backendHelp = ` The database backend supports using many different databases as secret backends, including but not limited to: diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 39eb3d000819..4c0863fd7e10 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -216,16 +216,8 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { b.Lock() defer b.Unlock() - if _, ok := b.connections[name]; ok { - // Close and remove the old connection - err := b.connections[name].Close() - if err != nil { - db.Close() - return nil, err - } - - delete(b.connections, name) - } + // Close and remove the old connection + b.clearConnection(name) // Save the new connection b.connections[name] = db diff --git a/builtin/logical/database/path_creds_create.go b/builtin/logical/database/path_creds_create.go index 60f0c5e3ed1c..7bc7dfa6fe7e 100644 --- a/builtin/logical/database/path_creds_create.go +++ b/builtin/logical/database/path_creds_create.go @@ -52,7 +52,9 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc { return nil, logical.ErrPermissionDenied } + // Grab the read lock b.RLock() + var unlockFunc func() = b.RUnlock // Get the Database object db, ok := b.getDBObj(role.DBName) @@ -60,22 +62,24 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc { // Upgrade lock b.RUnlock() b.Lock() - defer b.Unlock() + unlockFunc = b.Unlock // Create a new DB object db, err = b.createDBObj(req.Storage, role.DBName) if err != nil { + unlockFunc() return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) } - } else { - defer b.RUnlock() } expiration := time.Now().Add(role.DefaultTTL) // Create the user username, password, err := db.CreateUser(role.Statements, req.DisplayName, expiration) + // Unlock + unlockFunc() if err != nil { + b.closeIfShutdown(role.DBName, err) return nil, err } diff --git a/builtin/logical/database/secret_creds.go b/builtin/logical/database/secret_creds.go index 690b41565e25..c3dfcb973368 100644 --- a/builtin/logical/database/secret_creds.go +++ b/builtin/logical/database/secret_creds.go @@ -49,6 +49,7 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc { // Grab the read lock b.RLock() + var unlockFunc func() = b.RUnlock // Get the Database object db, ok := b.getDBObj(role.DBName) @@ -56,21 +57,23 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc { // Upgrade lock b.RUnlock() b.Lock() - defer b.Unlock() + unlockFunc = b.Unlock // Create a new DB object db, err = b.createDBObj(req.Storage, role.DBName) if err != nil { + unlockFunc() return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) } - } else { - defer b.RUnlock() } // Make sure we increase the VALID UNTIL endpoint for this user. if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() { err := db.RenewUser(role.Statements, username, expireTime) + // Unlock + unlockFunc() if err != nil { + b.closeIfShutdown(role.DBName, err) return nil, err } } @@ -105,6 +108,7 @@ func (b *databaseBackend) secretCredsRevoke() framework.OperationFunc { // Grab the read lock b.RLock() + var unlockFunc func() = b.RUnlock // Get our connection db, ok := b.getDBObj(role.DBName) @@ -112,19 +116,21 @@ func (b *databaseBackend) secretCredsRevoke() framework.OperationFunc { // Upgrade lock b.RUnlock() b.Lock() - defer b.Unlock() + unlockFunc = b.Unlock // Create a new DB object db, err = b.createDBObj(req.Storage, role.DBName) if err != nil { + unlockFunc() return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) } - } else { - defer b.RUnlock() } err = db.RevokeUser(role.Statements, username) + // Unlock + unlockFunc() if err != nil { + b.closeIfShutdown(role.DBName, err) return nil, err } From 2e2d3827da2ee3088a6b3c2f9db81f9dbd9640ab Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 26 Apr 2017 16:43:42 -0700 Subject: [PATCH 111/152] Add check to ensure we don't overwrite existing connections --- builtin/logical/database/backend.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index 4bac4b0c36e8..da8c8384acd2 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -78,12 +78,17 @@ func (b *databaseBackend) getDBObj(name string) (dbplugin.Database, bool) { // caches it in the connections map. The caller of this function needs to hold // the backend's write lock func (b *databaseBackend) createDBObj(s logical.Storage, name string) (dbplugin.Database, error) { + db, ok := b.connections[name] + if ok { + return db, nil + } + config, err := b.DatabaseConfig(s, name) if err != nil { return nil, err } - db, err := dbplugin.PluginFactory(config.PluginName, b.System(), b.logger) + db, err = dbplugin.PluginFactory(config.PluginName, b.System(), b.logger) if err != nil { return nil, err } From 230a36c5a197724c5da6dd86b3ef02428b4e9e22 Mon Sep 17 00:00:00 2001 From: Calvin Leung Huang Date: Thu, 27 Apr 2017 11:07:52 -0400 Subject: [PATCH 112/152] Update New() func signature and its references --- plugins/database/cassandra/cassandra.go | 11 +++++++---- plugins/database/cassandra/cassandra_test.go | 12 ++++++++---- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/plugins/database/cassandra/cassandra.go b/plugins/database/cassandra/cassandra.go index 15df0352e445..24d87a353464 100644 --- a/plugins/database/cassandra/cassandra.go +++ b/plugins/database/cassandra/cassandra.go @@ -24,7 +24,7 @@ type Cassandra struct { credsutil.CredentialsProducer } -func New() *Cassandra { +func New() (interface{}, error) { connProducer := &connutil.CassandraConnectionProducer{} connProducer.Type = cassandraTypeName @@ -35,14 +35,17 @@ func New() *Cassandra { CredentialsProducer: credsProducer, } - return dbType + return dbType, nil } // Run instantiates a MySQL object, and runs the RPC server for the plugin func Run() error { - dbType := New() + dbType, err := New() + if err != nil { + return err + } - dbplugin.NewPluginServer(dbType) + dbplugin.NewPluginServer(dbType.(*Cassandra)) return nil } diff --git a/plugins/database/cassandra/cassandra_test.go b/plugins/database/cassandra/cassandra_test.go index b81c32710c45..9e98ec48f558 100644 --- a/plugins/database/cassandra/cassandra_test.go +++ b/plugins/database/cassandra/cassandra_test.go @@ -80,7 +80,8 @@ func TestCassandra_Initialize(t *testing.T) { "protocol_version": 4, } - db := New() + dbRaw, _ := New() + db := dbRaw.(*Cassandra) connProducer := db.ConnectionProducer.(*connutil.CassandraConnectionProducer) err := db.Initialize(connectionDetails, true) @@ -109,7 +110,8 @@ func TestCassandra_CreateUser(t *testing.T) { "protocol_version": 4, } - db := New() + dbRaw, _ := New() + db := dbRaw.(*Cassandra) err := db.Initialize(connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) @@ -140,7 +142,8 @@ func TestMyCassandra_RenewUser(t *testing.T) { "protocol_version": 4, } - db := New() + dbRaw, _ := New() + db := dbRaw.(*Cassandra) err := db.Initialize(connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) @@ -176,7 +179,8 @@ func TestCassandra_RevokeUser(t *testing.T) { "protocol_version": 4, } - db := New() + dbRaw, _ := New() + db := dbRaw.(*Cassandra) err := db.Initialize(connectionDetails, true) if err != nil { t.Fatalf("err: %s", err) From 766b90976dae8afcdc612b08a876f5861849e17b Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 27 Apr 2017 22:56:06 -0700 Subject: [PATCH 113/152] If user provides a revocation statement for MSSQL plugin honor it --- plugins/database/mssql/mssql.go | 45 +++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/plugins/database/mssql/mssql.go b/plugins/database/mssql/mssql.go index 48da8ff08b1a..a0d863080861 100644 --- a/plugins/database/mssql/mssql.go +++ b/plugins/database/mssql/mssql.go @@ -142,6 +142,51 @@ func (m *MSSQL) RenewUser(statements dbplugin.Statements, username string, expir // then kill pending connections from that user, and finally drop the user and login from the // database instance. func (m *MSSQL) RevokeUser(statements dbplugin.Statements, username string) error { + if statements.RevocationStatements == "" { + return m.revokeUserDefault(username) + } + + // Get connection + db, err := m.getConnection() + if err != nil { + return err + } + + // Start a transaction + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + // Execute each query + for _, query := range strutil.ParseArbitraryStringSlice(statements.RevocationStatements, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{ + "name": username, + })) + if err != nil { + return err + } + defer stmt.Close() + if _, err := stmt.Exec(); err != nil { + return err + } + } + + // Commit the transaction + if err := tx.Commit(); err != nil { + return err + } + + return nil +} + +func (m *MSSQL) revokeUserDefault(username string) error { // Get connection db, err := m.getConnection() if err != nil { From 6684e5c91dbda1bbab3b120fa437dccd04a255af Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 27 Apr 2017 22:59:22 -0700 Subject: [PATCH 114/152] Update username length for MSSQL --- plugins/database/mssql/mssql.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/plugins/database/mssql/mssql.go b/plugins/database/mssql/mssql.go index a0d863080861..b608428e5cf0 100644 --- a/plugins/database/mssql/mssql.go +++ b/plugins/database/mssql/mssql.go @@ -26,8 +26,8 @@ func New() (interface{}, error) { connProducer.Type = msSQLTypeName credsProducer := &credsutil.SQLCredentialsProducer{ - DisplayNameLen: 4, - UsernameLen: 16, + DisplayNameLen: 20, + UsernameLen: 128, } dbType := &MSSQL{ From 445a0e339b87a839d6ccbaee6e3acfc54fda47a2 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 27 Apr 2017 23:02:33 -0700 Subject: [PATCH 115/152] Update the username length for postgresql --- plugins/database/postgresql/postgresql.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/plugins/database/postgresql/postgresql.go b/plugins/database/postgresql/postgresql.go index 5781b6c3d166..e90e0f8cbcbf 100644 --- a/plugins/database/postgresql/postgresql.go +++ b/plugins/database/postgresql/postgresql.go @@ -22,8 +22,8 @@ func New() (interface{}, error) { connProducer.Type = postgreSQLTypeName credsProducer := &credsutil.SQLCredentialsProducer{ - DisplayNameLen: 4, - UsernameLen: 16, + DisplayNameLen: 10, + UsernameLen: 63, } dbType := &PostgreSQL{ From f3e7ad7669a81e7314d3580f3d94ffd8badeb1b6 Mon Sep 17 00:00:00 2001 From: Calvin Leung Huang Date: Mon, 1 May 2017 11:27:35 -0400 Subject: [PATCH 116/152] Honor statements for RevokeUser on Cassandra backend, add method comments --- plugins/database/cassandra/cassandra.go | 44 ++++++++++++++----- .../cassandra/test-fixtures/cassandra.yaml | 8 ++-- 2 files changed, 36 insertions(+), 16 deletions(-) diff --git a/plugins/database/cassandra/cassandra.go b/plugins/database/cassandra/cassandra.go index 24d87a353464..bf1cbab92cf4 100644 --- a/plugins/database/cassandra/cassandra.go +++ b/plugins/database/cassandra/cassandra.go @@ -1,11 +1,11 @@ package cassandra import ( - "fmt" "strings" "time" "github.com/gocql/gocql" + multierror "github.com/hashicorp/go-multierror" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/helper/strutil" "github.com/hashicorp/vault/plugins/helper/database/connutil" @@ -14,16 +14,18 @@ import ( ) const ( - defaultCreationCQL = `CREATE USER '{{username}}' WITH PASSWORD '{{password}}' NOSUPERUSER;` - defaultRollbackCQL = `DROP USER '{{username}}';` - cassandraTypeName = "cassandra" + defaultUserCreationCQL = `CREATE USER '{{username}}' WITH PASSWORD '{{password}}' NOSUPERUSER;` + defaultUserDeletionCQL = `DROP USER '{{username}}';` + cassandraTypeName = "cassandra" ) +// Cassandra is an implementation of Database interface type Cassandra struct { connutil.ConnectionProducer credsutil.CredentialsProducer } +// New returns a new Cassandra instance func New() (interface{}, error) { connProducer := &connutil.CassandraConnectionProducer{} connProducer.Type = cassandraTypeName @@ -38,7 +40,7 @@ func New() (interface{}, error) { return dbType, nil } -// Run instantiates a MySQL object, and runs the RPC server for the plugin +// Run instantiates a Cassandra object, and runs the RPC server for the plugin func Run() error { dbType, err := New() if err != nil { @@ -50,6 +52,7 @@ func Run() error { return nil } +// Type returns the TypeName for this backend func (c *Cassandra) Type() (string, error) { return cassandraTypeName, nil } @@ -63,6 +66,8 @@ func (c *Cassandra) getConnection() (*gocql.Session, error) { return session.(*gocql.Session), nil } +// CreateUser generates the username/password on the underlying Cassandra secret backend as instructed by +// the CreationStatement provided. func (c *Cassandra) CreateUser(statements dbplugin.Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) { // Grab the lock c.Lock() @@ -76,11 +81,11 @@ func (c *Cassandra) CreateUser(statements dbplugin.Statements, usernamePrefix st creationCQL := statements.CreationStatements if creationCQL == "" { - creationCQL = defaultCreationCQL + creationCQL = defaultUserCreationCQL } rollbackCQL := statements.RollbackStatements if rollbackCQL == "" { - rollbackCQL = defaultRollbackCQL + rollbackCQL = defaultUserDeletionCQL } username, err = c.GenerateUsername(usernamePrefix) @@ -113,7 +118,6 @@ func (c *Cassandra) CreateUser(statements dbplugin.Statements, usernamePrefix st session.Query(dbutil.QueryHelper(query, map[string]string{ "username": username, - "password": password, })).Exec() } return "", "", err @@ -123,11 +127,13 @@ func (c *Cassandra) CreateUser(statements dbplugin.Statements, usernamePrefix st return username, password, nil } +// RenewUser is not supported on Cassandra, so this is a no-op. func (c *Cassandra) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { // NOOP return nil } +// RevokeUser attempts to drop the specified user. func (c *Cassandra) RevokeUser(statements dbplugin.Statements, username string) error { // Grab the lock c.Lock() @@ -138,10 +144,24 @@ func (c *Cassandra) RevokeUser(statements dbplugin.Statements, username string) return err } - err = session.Query(fmt.Sprintf("DROP USER '%s'", username)).Exec() - if err != nil { - return fmt.Errorf("error removing user '%s': %s", username, err) + revocationCQL := statements.RevocationStatements + if revocationCQL == "" { + revocationCQL = defaultUserDeletionCQL } - return nil + var result *multierror.Error + for _, query := range strutil.ParseArbitraryStringSlice(revocationCQL, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + err := session.Query(dbutil.QueryHelper(query, map[string]string{ + "username": username, + })).Exec() + + result = multierror.Append(result, err) + } + + return result.ErrorOrNil() } diff --git a/plugins/database/cassandra/test-fixtures/cassandra.yaml b/plugins/database/cassandra/test-fixtures/cassandra.yaml index 5b12c8cf4e69..54f47d34ac62 100644 --- a/plugins/database/cassandra/test-fixtures/cassandra.yaml +++ b/plugins/database/cassandra/test-fixtures/cassandra.yaml @@ -421,7 +421,7 @@ seed_provider: parameters: # seeds is actually a comma-delimited list of addresses. # Ex: ",," - - seeds: "172.17.0.3" + - seeds: "172.17.0.2" # For workloads with more data than can fit in memory, Cassandra's # bottleneck will be reads that need to fetch data from @@ -572,7 +572,7 @@ ssl_storage_port: 7001 # # Setting listen_address to 0.0.0.0 is always wrong. # -listen_address: 172.17.0.3 +listen_address: 172.17.0.2 # Set listen_address OR listen_interface, not both. Interfaces must correspond # to a single address, IP aliasing is not supported. @@ -586,7 +586,7 @@ listen_address: 172.17.0.3 # Address to broadcast to other Cassandra nodes # Leaving this blank will set it to the same value as listen_address -broadcast_address: 172.17.0.3 +broadcast_address: 172.17.0.2 # When using multiple physical network interfaces, set this # to true to listen on broadcast_address in addition to @@ -668,7 +668,7 @@ rpc_port: 9160 # be set to 0.0.0.0. If left blank, this will be set to the value of # rpc_address. If rpc_address is set to 0.0.0.0, broadcast_rpc_address must # be set. -broadcast_rpc_address: 172.17.0.3 +broadcast_rpc_address: 172.17.0.2 # enable or disable keepalive on rpc/native connections rpc_keepalive: true From b87f8a13ed1146de9b97fe67d2cd359a6af97a12 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 1 May 2017 14:59:55 -0700 Subject: [PATCH 117/152] Update interface name from Wrapper to a more descriptive RunnerUtil --- builtin/logical/database/dbplugin/client.go | 2 +- builtin/logical/database/dbplugin/plugin.go | 2 +- helper/pluginutil/runner.go | 8 ++++---- helper/pluginutil/tls.go | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/builtin/logical/database/dbplugin/client.go b/builtin/logical/database/dbplugin/client.go index 8cfc3aad00a1..0c095f891035 100644 --- a/builtin/logical/database/dbplugin/client.go +++ b/builtin/logical/database/dbplugin/client.go @@ -29,7 +29,7 @@ func (dc *DatabasePluginClient) Close() error { // newPluginClient returns a databaseRPCClient with a connection to a running // plugin. The client is wrapped in a DatabasePluginClient object to ensure the // plugin is killed on call of Close(). -func newPluginClient(sys pluginutil.Wrapper, pluginRunner *pluginutil.PluginRunner) (Database, error) { +func newPluginClient(sys pluginutil.RunnerUtil, pluginRunner *pluginutil.PluginRunner) (Database, error) { // pluginMap is the map of plugins we can dispense. var pluginMap = map[string]plugin.Plugin{ "database": new(DatabasePlugin), diff --git a/builtin/logical/database/dbplugin/plugin.go b/builtin/logical/database/dbplugin/plugin.go index 21812423c1a2..941f7aa04d45 100644 --- a/builtin/logical/database/dbplugin/plugin.go +++ b/builtin/logical/database/dbplugin/plugin.go @@ -31,7 +31,7 @@ type Statements struct { // PluginFactory is used to build plugin database types. It wraps the database // object in a logging and metrics middleware. -func PluginFactory(pluginName string, sys pluginutil.LookWrapper, logger log.Logger) (Database, error) { +func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log.Logger) (Database, error) { // Look for plugin in the plugin catalog pluginRunner, err := sys.LookupPlugin(pluginName) if err != nil { diff --git a/helper/pluginutil/runner.go b/helper/pluginutil/runner.go index 6a8df73855d8..0617f7624528 100644 --- a/helper/pluginutil/runner.go +++ b/helper/pluginutil/runner.go @@ -19,15 +19,15 @@ type Looker interface { // Wrapper interface defines the functions needed by the runner to wrap the // metadata needed to run a plugin process. This includes looking up Mlock // configuration and wrapping data in a respose wrapped token. -type Wrapper interface { +type RunnerUtil interface { ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (*wrapping.ResponseWrapInfo, error) MlockEnabled() bool } // LookWrapper defines the functions for both Looker and Wrapper -type LookWrapper interface { +type LookRunnerUtil interface { Looker - Wrapper + RunnerUtil } // PluginRunner defines the metadata needed to run a plugin securely with @@ -43,7 +43,7 @@ type PluginRunner struct { // Run takes a wrapper instance, and the go-plugin paramaters and executes a // plugin. -func (r *PluginRunner) Run(wrapper Wrapper, pluginMap map[string]plugin.Plugin, hs plugin.HandshakeConfig, env []string) (*plugin.Client, error) { +func (r *PluginRunner) Run(wrapper RunnerUtil, pluginMap map[string]plugin.Plugin, hs plugin.HandshakeConfig, env []string) (*plugin.Client, error) { // Get a CA TLS Certificate certBytes, key, err := GenerateCert() if err != nil { diff --git a/helper/pluginutil/tls.go b/helper/pluginutil/tls.go index ee0c54d89d4c..05804a33bbf4 100644 --- a/helper/pluginutil/tls.go +++ b/helper/pluginutil/tls.go @@ -97,7 +97,7 @@ func CreateClientTLSConfig(certBytes []byte, key *ecdsa.PrivateKey) (*tls.Config // WrapServerConfig is used to create a server certificate and private key, then // wrap them in an unwrap token for later retrieval by the plugin. -func WrapServerConfig(sys Wrapper, certBytes []byte, key *ecdsa.PrivateKey) (string, error) { +func WrapServerConfig(sys RunnerUtil, certBytes []byte, key *ecdsa.PrivateKey) (string, error) { rawKey, err := x509.MarshalECPrivateKey(key) if err != nil { return "", err From 6ca436cdf5f6c7070fa8c950ac93bf6eb4595326 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 1 May 2017 15:30:56 -0700 Subject: [PATCH 118/152] Don't store an error response as a package variable --- builtin/logical/database/path_config_connection.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 4c0863fd7e10..f3767428586e 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -11,8 +11,8 @@ import ( ) var ( - respErrEmptyPluginName = logical.ErrorResponse("empty plugin name") - respErrEmptyName = logical.ErrorResponse("empty name attribute given") + respErrEmptyPluginName = "empty plugin name" + respErrEmptyName = "empty name attribute given" ) // DatabaseConfig is used by the Factory function to configure a Database @@ -51,7 +51,7 @@ func (b *databaseBackend) pathConnectionReset() framework.OperationFunc { return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { name := data.Get("name").(string) if name == "" { - return respErrEmptyName, nil + return logical.ErrorResponse(respErrEmptyName), nil } // Grab the mutex lock @@ -120,7 +120,7 @@ func (b *databaseBackend) connectionReadHandler() framework.OperationFunc { return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { name := data.Get("name").(string) if name == "" { - return respErrEmptyName, nil + return logical.ErrorResponse(respErrEmptyName), nil } entry, err := req.Storage.Get(fmt.Sprintf("config/%s", name)) @@ -146,7 +146,7 @@ func (b *databaseBackend) connectionDeleteHandler() framework.OperationFunc { return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { name := data.Get("name").(string) if name == "" { - return respErrEmptyName, nil + return logical.ErrorResponse(respErrEmptyName), nil } err := req.Storage.Delete(fmt.Sprintf("config/%s", name)) @@ -176,12 +176,12 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { pluginName := data.Get("plugin_name").(string) if pluginName == "" { - return respErrEmptyPluginName, nil + return logical.ErrorResponse(respErrEmptyPluginName), nil } name := data.Get("name").(string) if name == "" { - return respErrEmptyName, nil + return logical.ErrorResponse(respErrEmptyName), nil } verifyConnection := data.Get("verify_connection").(bool) From 66630f642dabc85d98f160d8906857679bb06123 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 1 May 2017 15:43:21 -0700 Subject: [PATCH 119/152] Add test for custiom mssql revoke statement --- plugins/database/mssql/mssql_test.go | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/plugins/database/mssql/mssql_test.go b/plugins/database/mssql/mssql_test.go index 0dc18cb3e2a3..830e38abbd33 100644 --- a/plugins/database/mssql/mssql_test.go +++ b/plugins/database/mssql/mssql_test.go @@ -122,6 +122,26 @@ func TestMSSQL_RevokeUser(t *testing.T) { if err := testCredsExist(t, connURL, username, password); err == nil { t.Fatal("Credentials were not revoked") } + + username, password, err = db.CreateUser(statements, "test", time.Now().Add(2*time.Second)) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err = testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } + + // Test custom revoke statememt + statements.RevocationStatements = testMSSQLDrop + err = db.RevokeUser(statements, username) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err := testCredsExist(t, connURL, username, password); err == nil { + t.Fatal("Credentials were not revoked") + } } func testCredsExist(t testing.TB, connURL, username, password string) error { @@ -140,3 +160,8 @@ const testMSSQLRole = ` CREATE LOGIN [{{name}}] WITH PASSWORD = '{{password}}'; CREATE USER [{{name}}] FOR LOGIN [{{name}}]; GRANT SELECT, INSERT, UPDATE, DELETE ON SCHEMA::dbo TO [{{name}}];` + +const testMSSQLDrop = ` +DROP USER [{{name}}]; +DROP LOGIN [{{name}}]; +` From d68f2837f6b714b2c383cb14c10d80b3769cb354 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Mon, 1 May 2017 15:45:17 -0700 Subject: [PATCH 120/152] Prepend a 'v-' to the sql username strings --- plugins/helper/database/credsutil/sql.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/helper/database/credsutil/sql.go b/plugins/helper/database/credsutil/sql.go index 23e98102f3e8..a7929ccb1a80 100644 --- a/plugins/helper/database/credsutil/sql.go +++ b/plugins/helper/database/credsutil/sql.go @@ -21,7 +21,7 @@ func (scp *SQLCredentialsProducer) GenerateUsername(displayName string) (string, if err != nil { return "", err } - username := fmt.Sprintf("%s-%s", displayName, userUUID) + username := fmt.Sprintf("v-%s-%s", displayName, userUUID) if scp.UsernameLen > 0 && len(username) > scp.UsernameLen { username = username[:scp.UsernameLen] } From 885398e341068223a7045f757bc29b715713e4e1 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 2 May 2017 01:59:36 -0700 Subject: [PATCH 121/152] Add internals doc for plugins --- website/source/docs/internals/plugins.html.md | 102 ++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 website/source/docs/internals/plugins.html.md diff --git a/website/source/docs/internals/plugins.html.md b/website/source/docs/internals/plugins.html.md new file mode 100644 index 000000000000..5c396573db8a --- /dev/null +++ b/website/source/docs/internals/plugins.html.md @@ -0,0 +1,102 @@ +--- +layout: "docs" +page_title: "Plugin System" +sidebar_current: "docs-internals-plugins" +description: |- + Learn about Vault's plugin system. +--- + +# Plugin System +Certain Vault backends utilize plugins to extend their functionality outside of +what is available in the core vault code. Often times these backends will +provide both builtin plugins and a mechanism for executing external plugins. +Builtin plugins are shipped with vault, often for commonly used implementations, +and require no additional operator intervention to run. Builtin plugins are +just like any other backend code inside vault. External plugins, on the other +hand, are not shipped with the vault binary and must be registered to vault by +a privileged vault user. This section of the documentation will describe the +architecture and security of external plugins. + +# Plugin Architecture +Vault's plugins are completely separate, standalone applications that Vault +executes and communicates with over RPC. This means the plugin process does not +share the same memory space as Vault and therefore can only access the +interfaces and arguments given to it. This also means a crash in a plugin can not +crash the entirety of Vault. + +## Plugin Communication +Vault creates a mutually authenticated TLS connection for communication with the +plugin's RPC server. While invoking the plugin process Vault passes a [wrapping +token](https://www.vaultproject.io/docs/concepts/response-wrapping.html) to the +plugin process' environment. This token is single use and has a short TTL. Once +unwrapped, it provides the plugin with a unique generated TLS certificate and +private key for it to use to talk to the original vault process. + +## Plugin Registration +An important aspect of Vault's plugin system is designed to ensure the plugin +invoked by vault is authentic and maintains integrity. There are two components +that a Vault operator needs to configure before external plugins can be run. + +### Plugin Directory +The plugin directory is a configuration option of Vault, and can be specified in +the [configuration file](https://www.vaultproject.io/docs/configuration/index.html). +This setting specifies a directory that all plugin binaries must live. A plugin +can not be added to vault unless it exists in the plugin directory. There is no +default for this configuration option, and if it is not set plugins can not be +added to vault. + +~> Warning: A vault operator should take care to lock down the permissions on +this directory to ensure a plugin can not be modified by an unauthorized user +between the time of the SHA check and the time of plugin execution. + +### Plugin Catalog +The plugin catalog is Vault's list of approved plugins. The catalog is stored in +Vault's barrier and can only be updated by a vault user with sudo permissions. +Upon adding a new plugin the SHA256 sum of the executable and the command that +should be used to run the plugin must be provided. The catalog will make sure +the executable referenced in the command exists in the plugin directory. When +added to the catalog the plugin is not automatically executed, it instead +becomes visible to backends and can be executed by them. + +### Plugin Execution +When a backend executes a plugin it first checks the executable's SHA256 sum +against the one configured in the plugin catalog. Like Vault, plugins support +the use of mlock when availible. + +# Plugin Development +Because Vault communicates to plugins over a RPC interface, you can build and +distribute a plugin for Vault without having to rebuild Vault itself. This makes +it easy for you to build a Vault plugin for your organization's internal use, +for a proprietary API that you don't want to open source, or to prototype +something before contributing it back to the main project. + +In theory, because the plugin interface is HTTP, you could even develop a plugin +using a completely different programming language! (Disclaimer, you would also +have to re-implement the plugin API which is not a trivial amount of work.) + +~> Advanced topic! Plugin development is a highly advanced topic in Vault, and +is not required knowledge for day-to-day usage. If you don't plan on writing any +plugins, we recommend not reading this section of the documentation. + +Developing a plugin is simple. The only knowledge necessary to write +a plugin is basic command-line skills and basic knowledge of the +[Go programming language](http://golang.org). + +You're plugin implementation just needs to satisfy the interface for the plugin +type you want to build. You can find these definitions in the docs for the +backend running the plugin. + +```go +package main + +import ( + plugin "github.com/hashicorp/vault/builtin/logcial/database/dbplugin" +) + +func main() { + plugin.Serve(new(MyPlugin)) +} +``` + +And that's basically it! You would just need to change MyPlugin to your actual +plugin. From 31541b7fddeaf2b2ac9a72b7db85ad33efda9536 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 2 May 2017 02:00:04 -0700 Subject: [PATCH 122/152] Add plugins interal page to the sidebar: --- website/source/layouts/docs.erb | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/website/source/layouts/docs.erb b/website/source/layouts/docs.erb index 8f2686e64439..32e2a7e7a003 100644 --- a/website/source/layouts/docs.erb +++ b/website/source/layouts/docs.erb @@ -35,6 +35,10 @@ > Replication + + > + Plugins + From 6ddfe9aa7f0d2ee9a36e22a206e2e3e1e6c48cc4 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 2 May 2017 02:00:39 -0700 Subject: [PATCH 123/152] Rename NewPluginServer to just Serve --- builtin/logical/database/dbplugin/server.go | 4 ++-- plugins/database/mssql/mssql.go | 2 +- plugins/database/mysql/mysql.go | 2 +- plugins/database/postgresql/postgresql.go | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/builtin/logical/database/dbplugin/server.go b/builtin/logical/database/dbplugin/server.go index 04cc3d7e9041..32c377e13138 100644 --- a/builtin/logical/database/dbplugin/server.go +++ b/builtin/logical/database/dbplugin/server.go @@ -7,10 +7,10 @@ import ( "github.com/hashicorp/vault/helper/pluginutil" ) -// NewPluginServer is called from within a plugin and wraps the provided +// Serve is called from within a plugin and wraps the provided // Database implementation in a databasePluginRPCServer object and starts a // RPC server. -func NewPluginServer(db Database) { +func Serve(db Database) { dbPlugin := &DatabasePlugin{ impl: db, } diff --git a/plugins/database/mssql/mssql.go b/plugins/database/mssql/mssql.go index b608428e5cf0..d82efce6f729 100644 --- a/plugins/database/mssql/mssql.go +++ b/plugins/database/mssql/mssql.go @@ -45,7 +45,7 @@ func Run() error { return err } - dbplugin.NewPluginServer(dbType.(*MSSQL)) + dbplugin.Serve(dbType.(*MSSQL)) return nil } diff --git a/plugins/database/mysql/mysql.go b/plugins/database/mysql/mysql.go index 6485aaa8625d..7eb680759f3e 100644 --- a/plugins/database/mysql/mysql.go +++ b/plugins/database/mysql/mysql.go @@ -48,7 +48,7 @@ func Run() error { return err } - dbplugin.NewPluginServer(dbType.(*MySQL)) + dbplugin.Serve(dbType.(*MySQL)) return nil } diff --git a/plugins/database/postgresql/postgresql.go b/plugins/database/postgresql/postgresql.go index e90e0f8cbcbf..0889a86f554d 100644 --- a/plugins/database/postgresql/postgresql.go +++ b/plugins/database/postgresql/postgresql.go @@ -41,7 +41,7 @@ func Run() error { return err } - dbplugin.NewPluginServer(dbType.(*PostgreSQL)) + dbplugin.Serve(dbType.(*PostgreSQL)) return nil } From 7f92c5f47f480b5a371f46bcbcdc5c9c4d20808e Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 2 May 2017 02:22:06 -0700 Subject: [PATCH 124/152] Fix documentation --- plugins/database/postgresql/postgresql.go | 2 +- website/source/docs/internals/plugins.html.md | 20 ++++++++++--------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/plugins/database/postgresql/postgresql.go b/plugins/database/postgresql/postgresql.go index 0889a86f554d..bc5b14544d22 100644 --- a/plugins/database/postgresql/postgresql.go +++ b/plugins/database/postgresql/postgresql.go @@ -34,7 +34,7 @@ func New() (interface{}, error) { return dbType, nil } -// Run instatiates a PostgreSQL object, and runs the RPC server for the plugin +// Run instantiates a PostgreSQL object, and runs the RPC server for the plugin func Run() error { dbType, err := New() if err != nil { diff --git a/website/source/docs/internals/plugins.html.md b/website/source/docs/internals/plugins.html.md index 5c396573db8a..a3baafff07d0 100644 --- a/website/source/docs/internals/plugins.html.md +++ b/website/source/docs/internals/plugins.html.md @@ -33,7 +33,7 @@ unwrapped, it provides the plugin with a unique generated TLS certificate and private key for it to use to talk to the original vault process. ## Plugin Registration -An important aspect of Vault's plugin system is designed to ensure the plugin +An important consideration of Vault's plugin system is to ensure the plugin invoked by vault is authentic and maintains integrity. There are two components that a Vault operator needs to configure before external plugins can be run. @@ -52,16 +52,18 @@ between the time of the SHA check and the time of plugin execution. ### Plugin Catalog The plugin catalog is Vault's list of approved plugins. The catalog is stored in Vault's barrier and can only be updated by a vault user with sudo permissions. -Upon adding a new plugin the SHA256 sum of the executable and the command that -should be used to run the plugin must be provided. The catalog will make sure -the executable referenced in the command exists in the plugin directory. When -added to the catalog the plugin is not automatically executed, it instead -becomes visible to backends and can be executed by them. +Upon adding a new plugin the plugin name, SHA256 sum of the executable, and the +command that should be used to run the plugin must be provided. The catalog will +make sure the executable referenced in the command exists in the plugin +directory. When added to the catalog the plugin is not automatically executed, +it instead becomes visible to backends and can be executed by them. ### Plugin Execution -When a backend executes a plugin it first checks the executable's SHA256 sum -against the one configured in the plugin catalog. Like Vault, plugins support -the use of mlock when availible. +When a backend wants to run a plugin, it first looks up the plugin, by name, in +the catalog. It then checks the executable's SHA256 sum against the one +configured in the plugin catalog. Finally vault runs the command configured in +the catalog, sending along the JWT formatted response wrapping token and mlock +settings (like Vault, plugins support the use of mlock when availible). # Plugin Development Because Vault communicates to plugins over a RPC interface, you can build and From d300c235979f564b70cca6bfa461a4cc4c00e3b2 Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Tue, 2 May 2017 16:26:32 -0400 Subject: [PATCH 125/152] Add website skeleton --- .../docs/secrets/databases/cassandra.html.md | 9 +++++++ .../docs/secrets/databases/index.html.md | 11 ++++++++ .../docs/secrets/databases/mssql.html.md | 9 +++++++ .../secrets/databases/mysql-maria.html.md | 9 +++++++ .../docs/secrets/databases/postgresql.html.md | 9 +++++++ website/source/layouts/docs.erb | 26 ++++++++++++++++--- 6 files changed, 69 insertions(+), 4 deletions(-) create mode 100644 website/source/docs/secrets/databases/cassandra.html.md create mode 100644 website/source/docs/secrets/databases/index.html.md create mode 100644 website/source/docs/secrets/databases/mssql.html.md create mode 100644 website/source/docs/secrets/databases/mysql-maria.html.md create mode 100644 website/source/docs/secrets/databases/postgresql.html.md diff --git a/website/source/docs/secrets/databases/cassandra.html.md b/website/source/docs/secrets/databases/cassandra.html.md new file mode 100644 index 000000000000..012e7db5b7d2 --- /dev/null +++ b/website/source/docs/secrets/databases/cassandra.html.md @@ -0,0 +1,9 @@ +--- +layout: "docs" +page_title: "Cassandra Database Plugin" +sidebar_current: "docs-secrets-databases-cassandra" +description: |- + The Cassandra plugin for Vault's Database backend generates database credentials to access Cassandra. +--- + +# Cassandra Database Plugin diff --git a/website/source/docs/secrets/databases/index.html.md b/website/source/docs/secrets/databases/index.html.md new file mode 100644 index 000000000000..20f7bbed2725 --- /dev/null +++ b/website/source/docs/secrets/databases/index.html.md @@ -0,0 +1,11 @@ +--- +layout: "docs" +page_title: "Databases" +sidebar_current: "docs-secrets-databases" +description: |- + Top page for database secret backend information +--- + +# Databases + +Something diff --git a/website/source/docs/secrets/databases/mssql.html.md b/website/source/docs/secrets/databases/mssql.html.md new file mode 100644 index 000000000000..32ecf7775eaf --- /dev/null +++ b/website/source/docs/secrets/databases/mssql.html.md @@ -0,0 +1,9 @@ +--- +layout: "docs" +page_title: "MSSQL Database Plugin" +sidebar_current: "docs-secrets-databases-mssql" +description: |- + The MSSQL plugin for Vault's Database backend generates database credentials to access Microsoft SQL Server. +--- + +# MSSQL Database Plugin diff --git a/website/source/docs/secrets/databases/mysql-maria.html.md b/website/source/docs/secrets/databases/mysql-maria.html.md new file mode 100644 index 000000000000..1ee601dbdbe3 --- /dev/null +++ b/website/source/docs/secrets/databases/mysql-maria.html.md @@ -0,0 +1,9 @@ +--- +layout: "docs" +page_title: "MySQL/MariaDB Database Plugin" +sidebar_current: "docs-secrets-databases-mysql-maria" +description: |- + The MySQL/MariaDB plugin for Vault's Database backend generates database credentials to access MySQL and MariaDB servers. +--- + +# MySQL/MariaDB Database Plugin diff --git a/website/source/docs/secrets/databases/postgresql.html.md b/website/source/docs/secrets/databases/postgresql.html.md new file mode 100644 index 000000000000..5de340043d72 --- /dev/null +++ b/website/source/docs/secrets/databases/postgresql.html.md @@ -0,0 +1,9 @@ +--- +layout: "docs" +page_title: "PostgreSQL Database Plugin" +sidebar_current: "docs-secrets-databases-postgresql" +description: |- + The PostgreSQL plugin for Vault's Database backend generates database credentials to access PostgreSQL. +--- + +# PostgreSQL Database Plugin diff --git a/website/source/layouts/docs.erb b/website/source/layouts/docs.erb index 32e2a7e7a003..a6afd9c2d921 100644 --- a/website/source/layouts/docs.erb +++ b/website/source/layouts/docs.erb @@ -208,7 +208,7 @@ > - Cassandra + Cassandra (Deprecated) > @@ -219,6 +219,24 @@ Cubbyhole + > + Databases (Beta) + + + > Generic @@ -228,11 +246,11 @@ > - MSSQL + MSSQL (Deprecated) > - MySQL + MySQL (Deprecated) > @@ -240,7 +258,7 @@ > - PostgreSQL + PostgreSQL (Deprecated) > From 1df8ec9ef73fc157704167975c3a9ec75cb4c0d5 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 2 May 2017 14:40:11 -0700 Subject: [PATCH 126/152] Update the api for serving plugins and provide a utility to pass TLS data for commuinicating with the vault process --- builtin/logical/database/backend_test.go | 68 ++++-- .../logical/database/dbplugin/plugin_test.go | 55 +++-- builtin/logical/database/dbplugin/server.go | 13 +- helper/pluginutil/runner.go | 42 ++++ helper/pluginutil/tls.go | 213 +++++++++--------- .../cassandra-database-plugin/main.go | 7 +- plugins/database/cassandra/cassandra.go | 6 +- .../mssql/mssql-database-plugin/main.go | 7 +- plugins/database/mssql/mssql.go | 6 +- .../mysql/mysql-database-plugin/main.go | 7 +- plugins/database/mysql/mysql.go | 6 +- .../postgresql-database-plugin/main.go | 7 +- plugins/database/postgresql/postgresql.go | 6 +- plugins/serve.go | 31 +++ vault/testing.go | 3 + 15 files changed, 317 insertions(+), 160 deletions(-) create mode 100644 plugins/serve.go diff --git a/builtin/logical/database/backend_test.go b/builtin/logical/database/backend_test.go index 08317cbdc42f..70ec22ee2fed 100644 --- a/builtin/logical/database/backend_test.go +++ b/builtin/logical/database/backend_test.go @@ -4,12 +4,13 @@ import ( "database/sql" "fmt" "log" - "net" + stdhttp "net/http" "os" "reflect" "sync" "testing" + "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/http" @@ -77,13 +78,30 @@ func preparePostgresTestContainer(t *testing.T, s logical.Storage, b logical.Bac return } -func getCore(t *testing.T) (*vault.Core, net.Listener, logical.SystemView, string) { - core, _, token, ln := vault.TestCoreUnsealedWithListener(t) - http.TestServerWithListener(t, ln, "", core) - sys := vault.TestDynamicSystemView(core) - vault.TestAddTestPlugin(t, core, "postgresql-database-plugin", "TestBackend_PluginMain") +func getCore(t *testing.T) ([]*vault.TestClusterCore, logical.SystemView) { + coreConfig := &vault.CoreConfig{ + LogicalBackends: map[string]logical.Factory{ + "database": Factory, + }, + } + + handler1 := stdhttp.NewServeMux() + handler2 := stdhttp.NewServeMux() + handler3 := stdhttp.NewServeMux() + + // Chicken-and-egg: Handler needs a core. So we create handlers first, then + // add routes chained to a Handler-created handler. + cores := vault.TestCluster(t, []stdhttp.Handler{handler1, handler2, handler3}, coreConfig, false) + handler1.Handle("/", http.Handler(cores[0].Core)) + handler2.Handle("/", http.Handler(cores[1].Core)) + handler3.Handle("/", http.Handler(cores[2].Core)) - return core, ln, sys, token + core := cores[0] + + sys := vault.TestDynamicSystemView(core.Core) + vault.TestAddTestPlugin(t, core.Core, "postgresql-database-plugin", "TestBackend_PluginMain") + + return cores, sys } func TestBackend_PluginMain(t *testing.T) { @@ -91,14 +109,20 @@ func TestBackend_PluginMain(t *testing.T) { return } - postgresql.Run() + err := postgresql.Run(&api.TLSConfig{Insecure: true}) + if err != nil { + t.Fatal(err) + } + t.Fatal("We shouldn't get here") } func TestBackend_config_connection(t *testing.T) { var resp *logical.Response var err error - _, ln, sys, _ := getCore(t) - defer ln.Close() + cores, sys := getCore(t) + for _, core := range cores { + defer core.CloseListeners() + } config := logical.TestBackendConfig() config.StorageView = &logical.InmemStorage{} @@ -147,8 +171,10 @@ func TestBackend_config_connection(t *testing.T) { } func TestBackend_basic(t *testing.T) { - _, ln, sys, _ := getCore(t) - defer ln.Close() + cores, sys := getCore(t) + for _, core := range cores { + defer core.CloseListeners() + } config := logical.TestBackendConfig() config.StorageView = &logical.InmemStorage{} @@ -238,8 +264,10 @@ func TestBackend_basic(t *testing.T) { } func TestBackend_connectionCrud(t *testing.T) { - _, ln, sys, _ := getCore(t) - defer ln.Close() + cores, sys := getCore(t) + for _, core := range cores { + defer core.CloseListeners() + } config := logical.TestBackendConfig() config.StorageView = &logical.InmemStorage{} @@ -383,8 +411,10 @@ func TestBackend_connectionCrud(t *testing.T) { } func TestBackend_roleCrud(t *testing.T) { - _, ln, sys, _ := getCore(t) - defer ln.Close() + cores, sys := getCore(t) + for _, core := range cores { + defer core.CloseListeners() + } config := logical.TestBackendConfig() config.StorageView = &logical.InmemStorage{} @@ -493,8 +523,10 @@ func TestBackend_roleCrud(t *testing.T) { } } func TestBackend_allowedRoles(t *testing.T) { - _, ln, sys, _ := getCore(t) - defer ln.Close() + cores, sys := getCore(t) + for _, core := range cores { + defer core.CloseListeners() + } config := logical.TestBackendConfig() config.StorageView = &logical.InmemStorage{} diff --git a/builtin/logical/database/dbplugin/plugin_test.go b/builtin/logical/database/dbplugin/plugin_test.go index 1587ba24a5b4..c38d85ed3969 100644 --- a/builtin/logical/database/dbplugin/plugin_test.go +++ b/builtin/logical/database/dbplugin/plugin_test.go @@ -2,15 +2,17 @@ package dbplugin_test import ( "errors" - "net" + stdhttp "net/http" "os" "testing" "time" + "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/http" "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/plugins" "github.com/hashicorp/vault/vault" log "github.com/mgutz/logxi/v1" ) @@ -72,13 +74,26 @@ func (m *mockPlugin) Close() error { return nil } -func getCore(t *testing.T) (*vault.Core, net.Listener, logical.SystemView) { - core, _, _, ln := vault.TestCoreUnsealedWithListener(t) - http.TestServerWithListener(t, ln, "", core) - sys := vault.TestDynamicSystemView(core) - vault.TestAddTestPlugin(t, core, "test-plugin", "TestPlugin_Main") +func getCore(t *testing.T) ([]*vault.TestClusterCore, logical.SystemView) { + coreConfig := &vault.CoreConfig{} - return core, ln, sys + handler1 := stdhttp.NewServeMux() + handler2 := stdhttp.NewServeMux() + handler3 := stdhttp.NewServeMux() + + // Chicken-and-egg: Handler needs a core. So we create handlers first, then + // add routes chained to a Handler-created handler. + cores := vault.TestCluster(t, []stdhttp.Handler{handler1, handler2, handler3}, coreConfig, false) + handler1.Handle("/", http.Handler(cores[0].Core)) + handler2.Handle("/", http.Handler(cores[1].Core)) + handler3.Handle("/", http.Handler(cores[2].Core)) + + core := cores[0] + + sys := vault.TestDynamicSystemView(core.Core) + vault.TestAddTestPlugin(t, core.Core, "test-plugin", "TestPlugin_Main") + + return cores, sys } // This is not an actual test case, it's a helper function that will be executed @@ -92,12 +107,14 @@ func TestPlugin_Main(t *testing.T) { users: make(map[string][]string), } - dbplugin.NewPluginServer(plugin) + plugins.Serve(plugin, &api.TLSConfig{Insecure: true}) } func TestPlugin_Initialize(t *testing.T) { - _, ln, sys := getCore(t) - defer ln.Close() + cores, sys := getCore(t) + for _, core := range cores { + defer core.CloseListeners() + } dbRaw, err := dbplugin.PluginFactory("test-plugin", sys, &log.NullLogger{}) if err != nil { @@ -120,8 +137,10 @@ func TestPlugin_Initialize(t *testing.T) { } func TestPlugin_CreateUser(t *testing.T) { - _, ln, sys := getCore(t) - defer ln.Close() + cores, sys := getCore(t) + for _, core := range cores { + defer core.CloseListeners() + } db, err := dbplugin.PluginFactory("test-plugin", sys, &log.NullLogger{}) if err != nil { @@ -155,8 +174,10 @@ func TestPlugin_CreateUser(t *testing.T) { } func TestPlugin_RenewUser(t *testing.T) { - _, ln, sys := getCore(t) - defer ln.Close() + cores, sys := getCore(t) + for _, core := range cores { + defer core.CloseListeners() + } db, err := dbplugin.PluginFactory("test-plugin", sys, &log.NullLogger{}) if err != nil { @@ -184,8 +205,10 @@ func TestPlugin_RenewUser(t *testing.T) { } func TestPlugin_RevokeUser(t *testing.T) { - _, ln, sys := getCore(t) - defer ln.Close() + cores, sys := getCore(t) + for _, core := range cores { + defer core.CloseListeners() + } db, err := dbplugin.PluginFactory("test-plugin", sys, &log.NullLogger{}) if err != nil { diff --git a/builtin/logical/database/dbplugin/server.go b/builtin/logical/database/dbplugin/server.go index 32c377e13138..9546d092c276 100644 --- a/builtin/logical/database/dbplugin/server.go +++ b/builtin/logical/database/dbplugin/server.go @@ -1,16 +1,15 @@ package dbplugin import ( - "fmt" + "crypto/tls" "github.com/hashicorp/go-plugin" - "github.com/hashicorp/vault/helper/pluginutil" ) // Serve is called from within a plugin and wraps the provided // Database implementation in a databasePluginRPCServer object and starts a // RPC server. -func Serve(db Database) { +func Serve(db Database, tlsProvider func() (*tls.Config, error)) { dbPlugin := &DatabasePlugin{ impl: db, } @@ -20,16 +19,10 @@ func Serve(db Database) { "database": dbPlugin, } - err := pluginutil.OptionallyEnableMlock() - if err != nil { - fmt.Println(err) - return - } - plugin.Serve(&plugin.ServeConfig{ HandshakeConfig: handshakeConfig, Plugins: pluginMap, - TLSProvider: pluginutil.VaultPluginTLSProvider, + TLSProvider: tlsProvider, }) } diff --git a/helper/pluginutil/runner.go b/helper/pluginutil/runner.go index 0617f7624528..91439a3b8b6e 100644 --- a/helper/pluginutil/runner.go +++ b/helper/pluginutil/runner.go @@ -2,11 +2,13 @@ package pluginutil import ( "crypto/sha256" + "flag" "fmt" "os/exec" "time" plugin "github.com/hashicorp/go-plugin" + "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/helper/wrapping" ) @@ -87,3 +89,43 @@ func (r *PluginRunner) Run(wrapper RunnerUtil, pluginMap map[string]plugin.Plugi return client, nil } + +type APIClientMeta struct { + // These are set by the command line flags. + flagCACert string + flagCAPath string + flagClientCert string + flagClientKey string + flagInsecure bool +} + +func (f *APIClientMeta) FlagSet() *flag.FlagSet { + fs := flag.NewFlagSet("tls settings", flag.ContinueOnError) + + fs.StringVar(&f.flagCACert, "ca-cert", "", "") + fs.StringVar(&f.flagCAPath, "ca-path", "", "") + fs.StringVar(&f.flagClientCert, "client-cert", "", "") + fs.StringVar(&f.flagClientKey, "client-key", "", "") + fs.BoolVar(&f.flagInsecure, "insecure", false, "") + fs.BoolVar(&f.flagInsecure, "tls-skip-verify", false, "") + + return fs +} + +func (f *APIClientMeta) GetTLSConfig() *api.TLSConfig { + // If we need custom TLS configuration, then set it + if f.flagCACert != "" || f.flagCAPath != "" || f.flagClientCert != "" || f.flagClientKey != "" || f.flagInsecure { + t := &api.TLSConfig{ + CACert: f.flagCACert, + CAPath: f.flagCAPath, + ClientCert: f.flagClientCert, + ClientKey: f.flagClientKey, + TLSServerName: "", + Insecure: f.flagInsecure, + } + + return t + } + + return nil +} diff --git a/helper/pluginutil/tls.go b/helper/pluginutil/tls.go index 05804a33bbf4..b355079d6e24 100644 --- a/helper/pluginutil/tls.go +++ b/helper/pluginutil/tls.go @@ -116,109 +116,114 @@ func WrapServerConfig(sys RunnerUtil, certBytes []byte, key *ecdsa.PrivateKey) ( // VaultPluginTLSProvider is run inside a plugin and retrives the response // wrapped TLS certificate from vault. It returns a configured TLS Config. -func VaultPluginTLSProvider() (*tls.Config, error) { - unwrapToken := os.Getenv(PluginUnwrapTokenEnv) - - // Ensure unwrap token is a JWT - if strings.Count(unwrapToken, ".") != 2 { - return nil, errors.New("Could not parse unwraptoken") - } - - // Parse the JWT and retrieve the vault address - wt, err := jws.ParseJWT([]byte(unwrapToken)) - if err != nil { - return nil, errors.New(fmt.Sprintf("error decoding token: %s", err)) - } - if wt == nil { - return nil, errors.New("nil decoded token") - } - - addrRaw := wt.Claims().Get("addr") - if addrRaw == nil { - return nil, errors.New("decoded token does not contain primary cluster address") - } - vaultAddr, ok := addrRaw.(string) - if !ok { - return nil, errors.New("decoded token's address not valid") - } - if vaultAddr == "" { - return nil, errors.New(`no address for the vault found`) - } - - // Sanity check the value - if _, err := url.Parse(vaultAddr); err != nil { - return nil, errors.New(fmt.Sprintf("error parsing the vault address: %s", err)) - } - - // Unwrap the token - clientConf := api.DefaultConfig() - clientConf.Address = vaultAddr - client, err := api.NewClient(clientConf) - if err != nil { - return nil, errwrap.Wrapf("error during token unwrap request: {{err}}", err) - } - - secret, err := client.Logical().Unwrap(unwrapToken) - if err != nil { - return nil, errwrap.Wrapf("error during token unwrap request: {{err}}", err) - } - if secret == nil { - return nil, errors.New("error during token unwrap request secret is nil") - } - - // Retrieve and parse the server's certificate - serverCertBytesRaw, ok := secret.Data["ServerCert"].(string) - if !ok { - return nil, errors.New("error unmarshalling certificate") - } - - serverCertBytes, err := base64.StdEncoding.DecodeString(serverCertBytesRaw) - if err != nil { - return nil, fmt.Errorf("error parsing certificate: %v", err) +func VaultPluginTLSProvider(apiTLSConfig *api.TLSConfig) func() (*tls.Config, error) { + return func() (*tls.Config, error) { + unwrapToken := os.Getenv(PluginUnwrapTokenEnv) + + // Ensure unwrap token is a JWT + if strings.Count(unwrapToken, ".") != 2 { + return nil, errors.New("Could not parse unwraptoken") + } + + // Parse the JWT and retrieve the vault address + wt, err := jws.ParseJWT([]byte(unwrapToken)) + if err != nil { + return nil, errors.New(fmt.Sprintf("error decoding token: %s", err)) + } + if wt == nil { + return nil, errors.New("nil decoded token") + } + + addrRaw := wt.Claims().Get("addr") + if addrRaw == nil { + return nil, errors.New("decoded token does not contain primary cluster address") + } + vaultAddr, ok := addrRaw.(string) + if !ok { + return nil, errors.New("decoded token's address not valid") + } + if vaultAddr == "" { + return nil, errors.New(`no address for the vault found`) + } + + // Sanity check the value + if _, err := url.Parse(vaultAddr); err != nil { + return nil, errors.New(fmt.Sprintf("error parsing the vault address: %s", err)) + } + + // Unwrap the token + clientConf := api.DefaultConfig() + clientConf.Address = vaultAddr + if apiTLSConfig != nil { + clientConf.ConfigureTLS(apiTLSConfig) + } + client, err := api.NewClient(clientConf) + if err != nil { + return nil, errwrap.Wrapf("error during api client creation: {{err}}", err) + } + + secret, err := client.Logical().Unwrap(unwrapToken) + if err != nil { + return nil, errwrap.Wrapf("error during token unwrap request: {{err}}", err) + } + if secret == nil { + return nil, errors.New("error during token unwrap request secret is nil") + } + + // Retrieve and parse the server's certificate + serverCertBytesRaw, ok := secret.Data["ServerCert"].(string) + if !ok { + return nil, errors.New("error unmarshalling certificate") + } + + serverCertBytes, err := base64.StdEncoding.DecodeString(serverCertBytesRaw) + if err != nil { + return nil, fmt.Errorf("error parsing certificate: %v", err) + } + + serverCert, err := x509.ParseCertificate(serverCertBytes) + if err != nil { + return nil, fmt.Errorf("error parsing certificate: %v", err) + } + + // Retrieve and parse the server's private key + serverKeyB64, ok := secret.Data["ServerKey"].(string) + if !ok { + return nil, errors.New("error unmarshalling certificate") + } + + serverKeyRaw, err := base64.StdEncoding.DecodeString(serverKeyB64) + if err != nil { + return nil, fmt.Errorf("error parsing certificate: %v", err) + } + + serverKey, err := x509.ParseECPrivateKey(serverKeyRaw) + if err != nil { + return nil, fmt.Errorf("error parsing certificate: %v", err) + } + + // Add CA cert to the cert pool + caCertPool := x509.NewCertPool() + caCertPool.AddCert(serverCert) + + // Build a certificate object out of the server's cert and private key. + cert := tls.Certificate{ + Certificate: [][]byte{serverCertBytes}, + PrivateKey: serverKey, + Leaf: serverCert, + } + + // Setup TLS config + tlsConfig := &tls.Config{ + ClientCAs: caCertPool, + RootCAs: caCertPool, + ClientAuth: tls.RequireAndVerifyClientCert, + // TLS 1.2 minimum + MinVersion: tls.VersionTLS12, + Certificates: []tls.Certificate{cert}, + } + tlsConfig.BuildNameToCertificate() + + return tlsConfig, nil } - - serverCert, err := x509.ParseCertificate(serverCertBytes) - if err != nil { - return nil, fmt.Errorf("error parsing certificate: %v", err) - } - - // Retrieve and parse the server's private key - serverKeyB64, ok := secret.Data["ServerKey"].(string) - if !ok { - return nil, errors.New("error unmarshalling certificate") - } - - serverKeyRaw, err := base64.StdEncoding.DecodeString(serverKeyB64) - if err != nil { - return nil, fmt.Errorf("error parsing certificate: %v", err) - } - - serverKey, err := x509.ParseECPrivateKey(serverKeyRaw) - if err != nil { - return nil, fmt.Errorf("error parsing certificate: %v", err) - } - - // Add CA cert to the cert pool - caCertPool := x509.NewCertPool() - caCertPool.AddCert(serverCert) - - // Build a certificate object out of the server's cert and private key. - cert := tls.Certificate{ - Certificate: [][]byte{serverCertBytes}, - PrivateKey: serverKey, - Leaf: serverCert, - } - - // Setup TLS config - tlsConfig := &tls.Config{ - ClientCAs: caCertPool, - RootCAs: caCertPool, - ClientAuth: tls.RequireAndVerifyClientCert, - // TLS 1.2 minimum - MinVersion: tls.VersionTLS12, - Certificates: []tls.Certificate{cert}, - } - tlsConfig.BuildNameToCertificate() - - return tlsConfig, nil } diff --git a/plugins/database/cassandra/cassandra-database-plugin/main.go b/plugins/database/cassandra/cassandra-database-plugin/main.go index 79f0e0dbe94e..bb3f44142195 100644 --- a/plugins/database/cassandra/cassandra-database-plugin/main.go +++ b/plugins/database/cassandra/cassandra-database-plugin/main.go @@ -4,11 +4,16 @@ import ( "fmt" "os" + "github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/plugins/database/cassandra" ) func main() { - err := cassandra.Run() + apiClientMeta := &pluginutil.APIClientMeta{} + flags := apiClientMeta.FlagSet() + flags.Parse(os.Args) + + err := cassandra.Run(apiClientMeta.GetTLSConfig()) if err != nil { fmt.Println(err) os.Exit(1) diff --git a/plugins/database/cassandra/cassandra.go b/plugins/database/cassandra/cassandra.go index bf1cbab92cf4..60e445ff6bae 100644 --- a/plugins/database/cassandra/cassandra.go +++ b/plugins/database/cassandra/cassandra.go @@ -6,8 +6,10 @@ import ( "github.com/gocql/gocql" multierror "github.com/hashicorp/go-multierror" + "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/helper/strutil" + "github.com/hashicorp/vault/plugins" "github.com/hashicorp/vault/plugins/helper/database/connutil" "github.com/hashicorp/vault/plugins/helper/database/credsutil" "github.com/hashicorp/vault/plugins/helper/database/dbutil" @@ -41,13 +43,13 @@ func New() (interface{}, error) { } // Run instantiates a Cassandra object, and runs the RPC server for the plugin -func Run() error { +func Run(apiTLSConfig *api.TLSConfig) error { dbType, err := New() if err != nil { return err } - dbplugin.NewPluginServer(dbType.(*Cassandra)) + plugins.Serve(dbType.(*Cassandra), apiTLSConfig) return nil } diff --git a/plugins/database/mssql/mssql-database-plugin/main.go b/plugins/database/mssql/mssql-database-plugin/main.go index ead1cf842306..d52fd13db0db 100644 --- a/plugins/database/mssql/mssql-database-plugin/main.go +++ b/plugins/database/mssql/mssql-database-plugin/main.go @@ -4,11 +4,16 @@ import ( "fmt" "os" + "github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/plugins/database/mssql" ) func main() { - err := mssql.Run() + apiClientMeta := &pluginutil.APIClientMeta{} + flags := apiClientMeta.FlagSet() + flags.Parse(os.Args) + + err := mssql.Run(apiClientMeta.GetTLSConfig()) if err != nil { fmt.Println(err) os.Exit(1) diff --git a/plugins/database/mssql/mssql.go b/plugins/database/mssql/mssql.go index d82efce6f729..9b22aa87cdfa 100644 --- a/plugins/database/mssql/mssql.go +++ b/plugins/database/mssql/mssql.go @@ -6,8 +6,10 @@ import ( "strings" "time" + "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/helper/strutil" + "github.com/hashicorp/vault/plugins" "github.com/hashicorp/vault/plugins/helper/database/connutil" "github.com/hashicorp/vault/plugins/helper/database/credsutil" "github.com/hashicorp/vault/plugins/helper/database/dbutil" @@ -39,13 +41,13 @@ func New() (interface{}, error) { } // Run instantiates a MSSQL object, and runs the RPC server for the plugin -func Run() error { +func Run(apiTLSConfig *api.TLSConfig) error { dbType, err := New() if err != nil { return err } - dbplugin.Serve(dbType.(*MSSQL)) + plugins.Serve(dbType.(*MSSQL), apiTLSConfig) return nil } diff --git a/plugins/database/mysql/mysql-database-plugin/main.go b/plugins/database/mysql/mysql-database-plugin/main.go index c0ec75c9cdc0..a9389f50420d 100644 --- a/plugins/database/mysql/mysql-database-plugin/main.go +++ b/plugins/database/mysql/mysql-database-plugin/main.go @@ -4,11 +4,16 @@ import ( "fmt" "os" + "github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/plugins/database/mysql" ) func main() { - err := mysql.Run() + apiClientMeta := &pluginutil.APIClientMeta{} + flags := apiClientMeta.FlagSet() + flags.Parse(os.Args) + + err := mysql.Run(apiClientMeta.GetTLSConfig()) if err != nil { fmt.Println(err) os.Exit(1) diff --git a/plugins/database/mysql/mysql.go b/plugins/database/mysql/mysql.go index 7eb680759f3e..7a44d7341f1a 100644 --- a/plugins/database/mysql/mysql.go +++ b/plugins/database/mysql/mysql.go @@ -5,8 +5,10 @@ import ( "strings" "time" + "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/helper/strutil" + "github.com/hashicorp/vault/plugins" "github.com/hashicorp/vault/plugins/helper/database/connutil" "github.com/hashicorp/vault/plugins/helper/database/credsutil" "github.com/hashicorp/vault/plugins/helper/database/dbutil" @@ -42,13 +44,13 @@ func New() (interface{}, error) { } // Run instantiates a MySQL object, and runs the RPC server for the plugin -func Run() error { +func Run(apiTLSConfig *api.TLSConfig) error { dbType, err := New() if err != nil { return err } - dbplugin.Serve(dbType.(*MySQL)) + plugins.Serve(dbType.(*MySQL), apiTLSConfig) return nil } diff --git a/plugins/database/postgresql/postgresql-database-plugin/main.go b/plugins/database/postgresql/postgresql-database-plugin/main.go index 9b9b813c4c19..e6acb0584748 100644 --- a/plugins/database/postgresql/postgresql-database-plugin/main.go +++ b/plugins/database/postgresql/postgresql-database-plugin/main.go @@ -4,11 +4,16 @@ import ( "fmt" "os" + "github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/plugins/database/postgresql" ) func main() { - err := postgresql.Run() + apiClientMeta := &pluginutil.APIClientMeta{} + flags := apiClientMeta.FlagSet() + flags.Parse(os.Args) + + err := postgresql.Run(apiClientMeta.GetTLSConfig()) if err != nil { fmt.Println(err) os.Exit(1) diff --git a/plugins/database/postgresql/postgresql.go b/plugins/database/postgresql/postgresql.go index bc5b14544d22..d60ef8bbe00c 100644 --- a/plugins/database/postgresql/postgresql.go +++ b/plugins/database/postgresql/postgresql.go @@ -6,8 +6,10 @@ import ( "strings" "time" + "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/helper/strutil" + "github.com/hashicorp/vault/plugins" "github.com/hashicorp/vault/plugins/helper/database/connutil" "github.com/hashicorp/vault/plugins/helper/database/credsutil" "github.com/hashicorp/vault/plugins/helper/database/dbutil" @@ -35,13 +37,13 @@ func New() (interface{}, error) { } // Run instantiates a PostgreSQL object, and runs the RPC server for the plugin -func Run() error { +func Run(apiTLSConfig *api.TLSConfig) error { dbType, err := New() if err != nil { return err } - dbplugin.Serve(dbType.(*PostgreSQL)) + plugins.Serve(dbType.(*PostgreSQL), apiTLSConfig) return nil } diff --git a/plugins/serve.go b/plugins/serve.go new file mode 100644 index 000000000000..263b301f7b06 --- /dev/null +++ b/plugins/serve.go @@ -0,0 +1,31 @@ +package plugins + +import ( + "fmt" + + "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/builtin/logical/database/dbplugin" + "github.com/hashicorp/vault/helper/pluginutil" +) + +// Serve is used to start a plugin's RPC server. It takes an interface that must +// implement a known plugin interface to vault and an optional api.TLSConfig for +// use during the inital unwrap request to vault. The api config is particulary +// useful when vault is setup to require client cert checking. +func Serve(plugin interface{}, tlsConfig *api.TLSConfig) { + tlsProvider := pluginutil.VaultPluginTLSProvider(tlsConfig) + + err := pluginutil.OptionallyEnableMlock() + if err != nil { + fmt.Println(err) + return + } + + switch p := plugin.(type) { + case dbplugin.Database: + dbplugin.Serve(p, tlsProvider) + default: + fmt.Println("Unsuported plugin type") + } + +} diff --git a/vault/testing.go b/vault/testing.go index b2fe36b332f6..36bbb1276d23 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -790,6 +790,7 @@ func TestCluster(t testing.TB, handlers []http.Handler, base *CoreConfig, unseal if err != nil { t.Fatalf("err: %v", err) } + c1.redirectAddr = coreConfig.RedirectAddr coreConfig.RedirectAddr = fmt.Sprintf("https://127.0.0.1:%d", c2lns[0].Address.Port) if coreConfig.ClusterAddr != "" { @@ -799,6 +800,7 @@ func TestCluster(t testing.TB, handlers []http.Handler, base *CoreConfig, unseal if err != nil { t.Fatalf("err: %v", err) } + c2.redirectAddr = coreConfig.RedirectAddr coreConfig.RedirectAddr = fmt.Sprintf("https://127.0.0.1:%d", c3lns[0].Address.Port) if coreConfig.ClusterAddr != "" { @@ -808,6 +810,7 @@ func TestCluster(t testing.TB, handlers []http.Handler, base *CoreConfig, unseal if err != nil { t.Fatalf("err: %v", err) } + c2.redirectAddr = coreConfig.RedirectAddr // // Clustering setup From 30a02eded010c222f9e7e6430955f5fb5c190207 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 2 May 2017 14:44:14 -0700 Subject: [PATCH 127/152] Don't need to explictly set redirectAddrs --- vault/testing.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/vault/testing.go b/vault/testing.go index 36bbb1276d23..b2fe36b332f6 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -790,7 +790,6 @@ func TestCluster(t testing.TB, handlers []http.Handler, base *CoreConfig, unseal if err != nil { t.Fatalf("err: %v", err) } - c1.redirectAddr = coreConfig.RedirectAddr coreConfig.RedirectAddr = fmt.Sprintf("https://127.0.0.1:%d", c2lns[0].Address.Port) if coreConfig.ClusterAddr != "" { @@ -800,7 +799,6 @@ func TestCluster(t testing.TB, handlers []http.Handler, base *CoreConfig, unseal if err != nil { t.Fatalf("err: %v", err) } - c2.redirectAddr = coreConfig.RedirectAddr coreConfig.RedirectAddr = fmt.Sprintf("https://127.0.0.1:%d", c3lns[0].Address.Port) if coreConfig.ClusterAddr != "" { @@ -810,7 +808,6 @@ func TestCluster(t testing.TB, handlers []http.Handler, base *CoreConfig, unseal if err != nil { t.Fatalf("err: %v", err) } - c2.redirectAddr = coreConfig.RedirectAddr // // Clustering setup From 6e7696b84001a444e710af3e039e054db26c9122 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 2 May 2017 14:52:48 -0700 Subject: [PATCH 128/152] Remove unused TestCoreUnsealedWithListener function --- vault/testing.go | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/vault/testing.go b/vault/testing.go index b2fe36b332f6..a8c1f16bdc92 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -234,18 +234,6 @@ func TestCoreUnsealedBackend(t testing.TB, backend physical.Backend) (*Core, [][ return core, keys, token } -func TestCoreUnsealedWithListener(t testing.TB) (*Core, [][]byte, string, net.Listener) { - core, keys, token := TestCoreUnsealed(t) - ln, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("err: %s", err) - } - addr := "http://" + ln.Addr().String() - core.redirectAddr = addr - - return core, keys, token, ln -} - func testTokenStore(t testing.TB, c *Core) *TokenStore { me := &MountEntry{ Table: credentialTableType, From fe86f06daf3affebdc465affe37d16998f8daa52 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 2 May 2017 15:59:08 -0700 Subject: [PATCH 129/152] Fix a few PR comments --- helper/pluginutil/runner.go | 1 - plugins/serve.go | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/helper/pluginutil/runner.go b/helper/pluginutil/runner.go index 91439a3b8b6e..9dbe5c51bb7b 100644 --- a/helper/pluginutil/runner.go +++ b/helper/pluginutil/runner.go @@ -106,7 +106,6 @@ func (f *APIClientMeta) FlagSet() *flag.FlagSet { fs.StringVar(&f.flagCAPath, "ca-path", "", "") fs.StringVar(&f.flagClientCert, "client-cert", "", "") fs.StringVar(&f.flagClientKey, "client-key", "", "") - fs.BoolVar(&f.flagInsecure, "insecure", false, "") fs.BoolVar(&f.flagInsecure, "tls-skip-verify", false, "") return fs diff --git a/plugins/serve.go b/plugins/serve.go index 263b301f7b06..a40fc5b14fea 100644 --- a/plugins/serve.go +++ b/plugins/serve.go @@ -25,7 +25,7 @@ func Serve(plugin interface{}, tlsConfig *api.TLSConfig) { case dbplugin.Database: dbplugin.Serve(p, tlsProvider) default: - fmt.Println("Unsuported plugin type") + fmt.Println("Unsupported plugin type") } } From dc5979e3ae2c1bb2b5a345e3689fc3380fe868b7 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 2 May 2017 16:20:07 -0700 Subject: [PATCH 130/152] Fix wording in docs --- website/source/docs/internals/plugins.html.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/website/source/docs/internals/plugins.html.md b/website/source/docs/internals/plugins.html.md index a3baafff07d0..f1a720efb2c4 100644 --- a/website/source/docs/internals/plugins.html.md +++ b/website/source/docs/internals/plugins.html.md @@ -35,7 +35,8 @@ private key for it to use to talk to the original vault process. ## Plugin Registration An important consideration of Vault's plugin system is to ensure the plugin invoked by vault is authentic and maintains integrity. There are two components -that a Vault operator needs to configure before external plugins can be run. +that a Vault operator needs to configure before external plugins can be run, the +plugin directory and the plugin catalog entry. ### Plugin Directory The plugin directory is a configuration option of Vault, and can be specified in @@ -52,7 +53,7 @@ between the time of the SHA check and the time of plugin execution. ### Plugin Catalog The plugin catalog is Vault's list of approved plugins. The catalog is stored in Vault's barrier and can only be updated by a vault user with sudo permissions. -Upon adding a new plugin the plugin name, SHA256 sum of the executable, and the +Upon adding a new plugin, the plugin name, SHA256 sum of the executable, and the command that should be used to run the plugin must be provided. The catalog will make sure the executable referenced in the command exists in the plugin directory. When added to the catalog the plugin is not automatically executed, From d230446b4d9c9d4ba5eade25e19a4e5cf6adb002 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 2 May 2017 17:04:49 -0700 Subject: [PATCH 131/152] Update docs and add cassandra as a builtin plugin --- helper/builtinplugins/builtin.go | 2 ++ website/source/docs/internals/plugins.html.md | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/helper/builtinplugins/builtin.go b/helper/builtinplugins/builtin.go index c20a92603dae..3dec8588be9d 100644 --- a/helper/builtinplugins/builtin.go +++ b/helper/builtinplugins/builtin.go @@ -1,6 +1,7 @@ package builtinplugins import ( + "github.com/hashicorp/vault/plugins/database/cassandra" "github.com/hashicorp/vault/plugins/database/mssql" "github.com/hashicorp/vault/plugins/database/mysql" "github.com/hashicorp/vault/plugins/database/postgresql" @@ -12,6 +13,7 @@ var plugins map[string]BuiltinFactory = map[string]BuiltinFactory{ "mysql-database-plugin": mysql.New, "postgresql-database-plugin": postgresql.New, "mssql-database-plugin": mssql.New, + "cassandra-database-plugin": cassandra.New, } func Get(name string) (BuiltinFactory, bool) { diff --git a/website/source/docs/internals/plugins.html.md b/website/source/docs/internals/plugins.html.md index f1a720efb2c4..600bc034e2ae 100644 --- a/website/source/docs/internals/plugins.html.md +++ b/website/source/docs/internals/plugins.html.md @@ -93,11 +93,11 @@ backend running the plugin. package main import ( - plugin "github.com/hashicorp/vault/builtin/logcial/database/dbplugin" + "github.com/hashicorp/vault/plugins" ) func main() { - plugin.Serve(new(MyPlugin)) + plugins.Serve(new(MyPlugin), nil) } ``` From 60753dcf129fe81e891da301e7fa08a73738c371 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 2 May 2017 17:19:49 -0700 Subject: [PATCH 132/152] Only wrap in tracing middleware if the logger is set to trace level --- .../database/dbplugin/databasemiddleware.go | 50 ++++++++----------- builtin/logical/database/dbplugin/plugin.go | 10 ++-- 2 files changed, 26 insertions(+), 34 deletions(-) diff --git a/builtin/logical/database/dbplugin/databasemiddleware.go b/builtin/logical/database/dbplugin/databasemiddleware.go index 13591e51628d..83f57ef87e80 100644 --- a/builtin/logical/database/dbplugin/databasemiddleware.go +++ b/builtin/logical/database/dbplugin/databasemiddleware.go @@ -23,57 +23,47 @@ func (mw *databaseTracingMiddleware) Type() (string, error) { } func (mw *databaseTracingMiddleware) CreateUser(statements Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) { - if mw.logger.IsTrace() { - defer func(then time.Time) { - mw.logger.Trace("database", "operation", "CreateUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) - }(time.Now()) + defer func(then time.Time) { + mw.logger.Trace("database", "operation", "CreateUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) + }(time.Now()) - mw.logger.Trace("database", "operation", "CreateUser", "status", "started", "type", mw.typeStr) - } + mw.logger.Trace("database", "operation", "CreateUser", "status", "started", "type", mw.typeStr) return mw.next.CreateUser(statements, usernamePrefix, expiration) } func (mw *databaseTracingMiddleware) RenewUser(statements Statements, username string, expiration time.Time) (err error) { - if mw.logger.IsTrace() { - defer func(then time.Time) { - mw.logger.Trace("database", "operation", "RenewUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) - }(time.Now()) + defer func(then time.Time) { + mw.logger.Trace("database", "operation", "RenewUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) + }(time.Now()) - mw.logger.Trace("database", "operation", "RenewUser", "status", "started", mw.typeStr) - } + mw.logger.Trace("database", "operation", "RenewUser", "status", "started", mw.typeStr) return mw.next.RenewUser(statements, username, expiration) } func (mw *databaseTracingMiddleware) RevokeUser(statements Statements, username string) (err error) { - if mw.logger.IsTrace() { - defer func(then time.Time) { - mw.logger.Trace("database", "operation", "RevokeUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) - }(time.Now()) + defer func(then time.Time) { + mw.logger.Trace("database", "operation", "RevokeUser", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) + }(time.Now()) - mw.logger.Trace("database", "operation", "RevokeUser", "status", "started", "type", mw.typeStr) - } + mw.logger.Trace("database", "operation", "RevokeUser", "status", "started", "type", mw.typeStr) return mw.next.RevokeUser(statements, username) } func (mw *databaseTracingMiddleware) Initialize(conf map[string]interface{}, verifyConnection bool) (err error) { - if mw.logger.IsTrace() { - defer func(then time.Time) { - mw.logger.Trace("database", "operation", "Initialize", "status", "finished", "type", mw.typeStr, "verify", verifyConnection, "err", err, "took", time.Since(then)) - }(time.Now()) + defer func(then time.Time) { + mw.logger.Trace("database", "operation", "Initialize", "status", "finished", "type", mw.typeStr, "verify", verifyConnection, "err", err, "took", time.Since(then)) + }(time.Now()) - mw.logger.Trace("database", "operation", "Initialize", "status", "started", "type", mw.typeStr) - } + mw.logger.Trace("database", "operation", "Initialize", "status", "started", "type", mw.typeStr) return mw.next.Initialize(conf, verifyConnection) } func (mw *databaseTracingMiddleware) Close() (err error) { - if mw.logger.IsTrace() { - defer func(then time.Time) { - mw.logger.Trace("database", "operation", "Close", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) - }(time.Now()) + defer func(then time.Time) { + mw.logger.Trace("database", "operation", "Close", "status", "finished", "type", mw.typeStr, "err", err, "took", time.Since(then)) + }(time.Now()) - mw.logger.Trace("database", "operation", "Close", "status", "started", "type", mw.typeStr) - } + mw.logger.Trace("database", "operation", "Close", "status", "started", "type", mw.typeStr) return mw.next.Close() } diff --git a/builtin/logical/database/dbplugin/plugin.go b/builtin/logical/database/dbplugin/plugin.go index 941f7aa04d45..bc63594ae8c0 100644 --- a/builtin/logical/database/dbplugin/plugin.go +++ b/builtin/logical/database/dbplugin/plugin.go @@ -73,10 +73,12 @@ func PluginFactory(pluginName string, sys pluginutil.LookRunnerUtil, logger log. } // Wrap with tracing middleware - db = &databaseTracingMiddleware{ - next: db, - typeStr: typeStr, - logger: logger, + if logger.IsTrace() { + db = &databaseTracingMiddleware{ + next: db, + typeStr: typeStr, + logger: logger, + } } return db, nil From 2be2e4c74e96502a946203219ee8eff8ca103e67 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Tue, 2 May 2017 22:24:31 -0700 Subject: [PATCH 133/152] Update docs for the database backend and it's plugins --- .../docs/secrets/databases/cassandra.html.md | 53 +++++++++++ .../docs/secrets/databases/index.html.md | 88 ++++++++++++++++++- .../docs/secrets/databases/mssql.html.md | 51 +++++++++++ .../secrets/databases/mysql-maria.html.md | 49 +++++++++++ .../docs/secrets/databases/postgresql.html.md | 51 +++++++++++ 5 files changed, 291 insertions(+), 1 deletion(-) diff --git a/website/source/docs/secrets/databases/cassandra.html.md b/website/source/docs/secrets/databases/cassandra.html.md index 012e7db5b7d2..99d3d3bf9765 100644 --- a/website/source/docs/secrets/databases/cassandra.html.md +++ b/website/source/docs/secrets/databases/cassandra.html.md @@ -7,3 +7,56 @@ description: |- --- # Cassandra Database Plugin + +Name: `cassandra-database-plugin` + +The Cassandra Database Plugin is one of the supported plugins for the Database +backend. This plugin generates database credentials dynamically based on +configured roles for the Cassandra database. + +See the [Database Backend](/docs/secret/database/index.html) docs for more +information about setting up the Database Backend. + +## Quick Start + +After the Database Backend is mounted you can configure a cassandra connection +by specifying this plugin as the `"plugin_name"` argument. Here is an example +cassandra configuration: + +``` +$ vault write database/config/cassandra \ + plugin_name=cassandra-database-plugin \ + allowed_roles="readonly" \ + hosts=localhost \ + username=cassandra \ + password=cassandra + +The following warnings were returned from the Vault server: +* Read access to this endpoint should be controlled via ACLs as it will return the connection details as is, including passwords, if any. +``` + +Once the cassandra connection is configured we can add a role: + +``` +$ vault write database/roles/readonly \ + db_name=cassandra \ + creation_statements="CREATE USER '{{username}}' WITH PASSWORD '{{password}}' NOSUPERUSER; \ + GRANT SELECT ON ALL KEYSPACES TO {{username}};" \ + default_ttl="1h" \ + max_ttl="24h" + + +Success! Data written to: database/roles/readonly +``` + +This role can be used to retrieve a new set of credentials by querying the +"database/creds/readonly" endpoint. + +## API + +The full list of configurable options can be seen in the [Cassandra database +plugin API](/api/secret/database/cassandra.html) page. + +Or for more information on the Database secret backend's HTTP API please see the [Database secret +backend API](/api/secret/database/index.html). + diff --git a/website/source/docs/secrets/databases/index.html.md b/website/source/docs/secrets/databases/index.html.md index 20f7bbed2725..cf366d9c90ef 100644 --- a/website/source/docs/secrets/databases/index.html.md +++ b/website/source/docs/secrets/databases/index.html.md @@ -8,4 +8,90 @@ description: |- # Databases -Something +Name: `Database` + +The Database secret backend for Vault generates database credentials dynamically +based on configured roles. It works with a number of different databases through +a plugin interface. There are a number of builtin database types and an exposed +framework for running custom database types for extendability. This means that +services that need to access a database no longer need to hardcode credentials: +they can request them from Vault, and use Vault's leasing mechanism to more +easily roll keys. + +Additionally, it introduces a new ability: with every service accessing the +database with unique credentials, it makes auditing much easier when +questionable data access is discovered: you can track it down to the specific +instance of a service based on the SQL username. + +Vault makes use of its own internal revocation system to ensure that users +become invalid within a reasonable time of the lease expiring. + +This page will show a quick start for this backend. For detailed documentation +on every path, use vault path-help after mounting the backend. + +## Quick Start + +The first step in using the Database backend is mounting it. + +```text +$ vault mount database +Successfully mounted 'database' at 'database'! +``` + +Next, we must configure this backend to connect to a database. In this example +we will connect to a MySQL database, but the configuration details needed for +other plugin types can be found in their docs pages. This backend can configure +multiple database connections, therefore a name for the connection must be +provide; we'll call this one simply "mysql". + +``` +$ vault write database/config/mysql \ + plugin_name=mysql-database-plugin \ + connection_url="root:mysql@tcp(127.0.0.1:3306)/" \ + allowed_roles="readonly" + +The following warnings were returned from the Vault server: +* Read access to this endpoint should be controlled via ACLs as it will return the connection details as is, including passwords, if any. +``` + +The next step is to configure a role. A role is a logical name that maps to a +policy used to generate those credentials. A role needs to be configured with +the database name we created above, and the default/max TTLs. For example, lets +create a "readonly" role: + +``` +$ vault write database/roles/readonly \ + db_name=mysql \ + creation_statements="CREATE USER '{{name}}'@'%' IDENTIFIED BY '{{password}}';GRANT SELECT ON *.* TO '{{name}}'@'%';" \ + default_ttl="1h" \ + max_ttl="24h" +Success! Data written to: database/roles/readonly +``` +By writing to the roles/readonly path we are defining the readonly role. This +role will be created by evaluating the given creation statements. By default, +the {{name}} and {{password}} fields will be populated by the plugin with +dynamically generated values. In other plugins the {{expiration}} field could +also be supported. This SQL statement is creating the named user, and then +granting it SELECT or read-only privileges to tables in the database. More +complex GRANT queries can be used to customize the privileges of the role. +Custom revocation statements could be passed too, but this plugin has a default +statement we can use. + +To generate a new set of credentials, we simply read from that role: + +``` +$ vault read database/creds/readonly +Key Value +--- ----- +lease_id database/creds/readonly/2f6a614c-4aa2-7b19-24b9-ad944a8d4de6 +lease_duration 1h0m0s +lease_renewable true +password 8cab931c-d62e-a73d-60d3-5ee85139cd66 +username v-root-e2978cd0- +``` + +## API + +The Database secret backend has a full HTTP API. Please see the [Database secret +backend API](/api/secret/database/index.html) for more details. + diff --git a/website/source/docs/secrets/databases/mssql.html.md b/website/source/docs/secrets/databases/mssql.html.md index 32ecf7775eaf..2d220ddb87e5 100644 --- a/website/source/docs/secrets/databases/mssql.html.md +++ b/website/source/docs/secrets/databases/mssql.html.md @@ -7,3 +7,54 @@ description: |- --- # MSSQL Database Plugin + +Name: `mssql-database-plugin` + +The MSSQL Database Plugin is one of the supported plugins for the Database +backend. This plugin generates database credentials dynamically based on +configured roles for the MSSQL database. + +See the [Database Backend](/docs/secret/database/index.html) docs for more +information about setting up the Database Backend. + +## Quick Start + +After the Database Backend is mounted you can configure a MSSQL connection +by specifying this plugin as the `"plugin_name"` argument. Here is an example +configuration: + +``` +$ vault write database/config/mssql \ + plugin_name=mssql-database-plugin \ + connection_url='sqlserver://sa:yourStrong(!)Password@localhost:1433' \ + allowed_roles="readonly" + +The following warnings were returned from the Vault server: +* Read access to this endpoint should be controlled via ACLs as it will return the connection details as is, including passwords, if any. +``` + +Once the MSSQL connection is configured we can add a role: + +``` +$ vault write database/roles/readonly \ + db_name=mssql \ + creation_statements="CREATE LOGIN [{{name}}] WITH PASSWORD = '{{password}}';\ + USE AdventureWorks; CREATE USER [{{name}}] FOR LOGIN [{{name}}]; \ + GRANT SELECT ON SCHEMA::dbo TO [{{name}}];" \ + default_ttl="1h" \ + max_ttl="24h" + +Success! Data written to: database/roles/readonly +``` + +This role can now be used to retrieve a new set of credentials by querying the +"database/creds/readonly" endpoint. + +## API + +The full list of configurable options can be seen in the [MSSQL database +plugin API](/api/secret/database/mssql.html) page. + +Or for more information on the Database secret backend's HTTP API please see the [Database secret +backend API](/api/secret/database/index.html). + diff --git a/website/source/docs/secrets/databases/mysql-maria.html.md b/website/source/docs/secrets/databases/mysql-maria.html.md index 1ee601dbdbe3..bd61cc43b530 100644 --- a/website/source/docs/secrets/databases/mysql-maria.html.md +++ b/website/source/docs/secrets/databases/mysql-maria.html.md @@ -7,3 +7,52 @@ description: |- --- # MySQL/MariaDB Database Plugin + +Name: `mysql-database-plugin` + +The MySQL Database Plugin is one of the supported plugins for the Database +backend. This plugin generates database credentials dynamically based on +configured roles for the MySQL database. + +See the [Database Backend](/docs/secret/database/index.html) docs for more +information about setting up the Database Backend. + +## Quick Start + +After the Database Backend is mounted you can configure a MySQL connection +by specifying this plugin as the `"plugin_name"` argument. Here is an example +configuration: + +``` +$ vault write database/config/mysql \ + plugin_name=mysql-database-plugin \ + connection_url="root:mysql@tcp(127.0.0.1:3306)/" \ + allowed_roles="readonly" + +The following warnings were returned from the Vault server: +* Read access to this endpoint should be controlled via ACLs as it will return the connection details as is, including passwords, if any. +``` + +Once the MySQL connection is configured we can add a role: + +``` +$ vault write database/roles/readonly \ + db_name=mysql \ + creation_statements="CREATE USER '{{name}}'@'%' IDENTIFIED BY '{{password}}';GRANT SELECT ON *.* TO '{{name}}'@'%';" \ + default_ttl="1h" \ + max_ttl="24h" + +Success! Data written to: database/roles/readonly +``` + +This role can now be used to retrieve a new set of credentials by querying the +"database/creds/readonly" endpoint. + +## API + +The full list of configurable options can be seen in the [MySQL database +plugin API](/api/secret/database/mysql.html) page. + +Or for more information on the Database secret backend's HTTP API please see the [Database secret +backend API](/api/secret/database/index.html). + diff --git a/website/source/docs/secrets/databases/postgresql.html.md b/website/source/docs/secrets/databases/postgresql.html.md index 5de340043d72..e5fee10ef05a 100644 --- a/website/source/docs/secrets/databases/postgresql.html.md +++ b/website/source/docs/secrets/databases/postgresql.html.md @@ -7,3 +7,54 @@ description: |- --- # PostgreSQL Database Plugin + +Name: `postgresql-database-plugin` + +The PostgreSQL Database Plugin is one of the supported plugins for the Database +backend. This plugin generates database credentials dynamically based on +configured roles for the PostgreSQL database. + +See the [Database Backend](/docs/secret/database/index.html) docs for more +information about setting up the Database Backend. + +## Quick Start + +After the Database Backend is mounted you can configure a PostgreSQL connection +by specifying this plugin as the `"plugin_name"` argument. Here is an example +configuration: + +``` +$ vault write database/config/postgresql \ + plugin_name=postgresql-database-plugin \ + allowed_roles="readonly" \ + connection_url="postgresql://root:root@localhost:5432/postgres" + +The following warnings were returned from the Vault server: +* Read access to this endpoint should be controlled via ACLs as it will return the connection details as is, including passwords, if any. +``` + +Once the PostgreSQL connection is configured we can add a role. The PostgreSQL +plugin replaces `{{expiration}}` in statements with a formated timestamp: + +``` +$ vault write database/roles/readonly \ + db_name=postgresql \ + creation_statements="CREATE ROLE \"{{name}}\" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}'; \ + GRANT SELECT ON ALL TABLES IN SCHEMA public TO \"{{name}}\";" \ + default_ttl="1h" \ + max_ttl="24h" + +Success! Data written to: database/roles/readonly +``` + +This role can be used to retrieve a new set of credentials by querying the +"database/creds/readonly" endpoint. + +## API + +The full list of configurable options can be seen in the [PostgreSQL database +plugin API](/api/secret/database/postgresql.html) page. + +Or for more information on the Database secret backend's HTTP API please see the [Database secret +backend API](/api/secret/database/index.html). + From 85967cb5a837de0e407eea56adac88a4d2b81981 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 3 May 2017 00:01:28 -0700 Subject: [PATCH 134/152] Add custom plugins docs page --- .../docs/secrets/databases/cassandra.html.md | 2 +- .../docs/secrets/databases/custom.html.md | 120 ++++++++++++++++++ .../docs/secrets/databases/mssql.html.md | 2 +- .../secrets/databases/mysql-maria.html.md | 2 +- .../docs/secrets/databases/postgresql.html.md | 2 +- website/source/layouts/docs.erb | 3 + 6 files changed, 127 insertions(+), 4 deletions(-) create mode 100644 website/source/docs/secrets/databases/custom.html.md diff --git a/website/source/docs/secrets/databases/cassandra.html.md b/website/source/docs/secrets/databases/cassandra.html.md index 99d3d3bf9765..b3d87f7ed238 100644 --- a/website/source/docs/secrets/databases/cassandra.html.md +++ b/website/source/docs/secrets/databases/cassandra.html.md @@ -58,5 +58,5 @@ The full list of configurable options can be seen in the [Cassandra database plugin API](/api/secret/database/cassandra.html) page. Or for more information on the Database secret backend's HTTP API please see the [Database secret -backend API](/api/secret/database/index.html). +backend API](/api/secret/database/index.html) page. diff --git a/website/source/docs/secrets/databases/custom.html.md b/website/source/docs/secrets/databases/custom.html.md new file mode 100644 index 000000000000..5911e1d809c2 --- /dev/null +++ b/website/source/docs/secrets/databases/custom.html.md @@ -0,0 +1,120 @@ +--- +layout: "docs" +page_title: "Custom Database Plugins" +sidebar_current: "docs-secrets-databases-custom" +description: |- + Creating custom database plugins for Vault's Database backend to generate credentials for a database. +--- + +# Custom Database Plugins + +The Database backend allows new functionality to be added through a plugin +interface without needing to modify vault's core code. This allows you write +your own code to generate credentials in any database you wish. It also allows +databases that require dynamically linked libraries to be used with vault. + +~> **Advanced topic!** Plugin development is a highly advanced +topic in Vault, and is not required knowledge for day-to-day usage. +If you don't plan on writing any plugins, we recommend not reading +this section of the documentation. + +Please read the [Plugins internals](/docs/internals/plugins.html) docs for more +information about the plugin system before getting started building your +Database plugin. + +## Plugin Interface + +All plugins for the Database backend must implement the same simple interface. + +```go +type Database interface { + Type() (string, error) + CreateUser(statements Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) + RenewUser(statements Statements, username string, expiration time.Time) error + RevokeUser(statements Statements, username string) error + + Initialize(config map[string]interface{}, verifyConnection bool) error + Close() error +} +``` + +You'll notice the first parameter to a number of those functions is a +`Statements` struct. This struct is used to pass the Role's configured +statements to the plugin on function call. The struct is defined as: + +```go +type Statements struct { + CreationStatements string + RevocationStatements string + RollbackStatements string + RenewStatements string +} +``` + +It is up to your plugin to replace the `{{name}}`, `{{password}}`, and +`{{expiration}}` in these statements with the proper vaules. + +The `Initialize` function is passed a map of keys to values, this data is what the +user specified as the configuration for the plugin. Your plugin should use this +data to make connections to the database. It is also passed a boolean value +specifying whether or not your plugin should return an error if it is unable to +connect to the database. + +## Serving your plugin + +Once your plugin is built you should pass it to vault's `plugins` package by +calling the `Serve` method: + +```go +package main + +import ( + "github.com/hashicorp/vault/plugins" +) + +func main() { + plugins.Serve(new(MyPlugin), nil) +} +``` + +Replacing `MyPlugin` with the actual implementation of your plugin. + +The second parameter to `Serve` takes in an optional vault `api.TLSConfig` for +configuring the plugin to communicate with vault for the initial unwrap call. +This if useful if your vault setup requires client certificate checks. This +config wont be used once the plugin unwraps its own TLS cert and key. + +## Running your plugin + +The above main package, once built, will supply you with a binary of your +plugin. We also recommend if you are planning on distributing your plugin to +build with [gox](https://github.com/mitchellh/gox) for cross platform builds. + +To use your plugin with the Database backend you need to place the binary in the +plugin directory as specified in the [plugin internals](/docs/internals/plugins.html) docs. + +You should now be able to register your plugin into the vault catalog. To do +this your token will need sudo permissions. + +``` +$ vault write sys/plugins/catalog/myplugin-database-plugin \ + sha_256= \ + command="myplugin" +Success! Data written to: sys/plugins/catalog/myplugin-database-plugin +``` + +Now you should be able to configure your plugin like any other: + +``` +$ vault write database/config/myplugin \ + plugin_name=myplugin-database-plugin \ + allowed_roles="readonly" \ + myplugins_connection_details=.... + +The following warnings were returned from the Vault server: +* Read access to this endpoint should be controlled via ACLs as it will return the connection details as is, including passwords, if any. +``` + + + + diff --git a/website/source/docs/secrets/databases/mssql.html.md b/website/source/docs/secrets/databases/mssql.html.md index 2d220ddb87e5..0eefe1764260 100644 --- a/website/source/docs/secrets/databases/mssql.html.md +++ b/website/source/docs/secrets/databases/mssql.html.md @@ -56,5 +56,5 @@ The full list of configurable options can be seen in the [MSSQL database plugin API](/api/secret/database/mssql.html) page. Or for more information on the Database secret backend's HTTP API please see the [Database secret -backend API](/api/secret/database/index.html). +backend API](/api/secret/database/index.html) page. diff --git a/website/source/docs/secrets/databases/mysql-maria.html.md b/website/source/docs/secrets/databases/mysql-maria.html.md index bd61cc43b530..76ca193fcfaf 100644 --- a/website/source/docs/secrets/databases/mysql-maria.html.md +++ b/website/source/docs/secrets/databases/mysql-maria.html.md @@ -54,5 +54,5 @@ The full list of configurable options can be seen in the [MySQL database plugin API](/api/secret/database/mysql.html) page. Or for more information on the Database secret backend's HTTP API please see the [Database secret -backend API](/api/secret/database/index.html). +backend API](/api/secret/database/index.html) page. diff --git a/website/source/docs/secrets/databases/postgresql.html.md b/website/source/docs/secrets/databases/postgresql.html.md index e5fee10ef05a..81716132f52b 100644 --- a/website/source/docs/secrets/databases/postgresql.html.md +++ b/website/source/docs/secrets/databases/postgresql.html.md @@ -56,5 +56,5 @@ The full list of configurable options can be seen in the [PostgreSQL database plugin API](/api/secret/database/postgresql.html) page. Or for more information on the Database secret backend's HTTP API please see the [Database secret -backend API](/api/secret/database/index.html). +backend API](/api/secret/database/index.html) page. diff --git a/website/source/layouts/docs.erb b/website/source/layouts/docs.erb index a6afd9c2d921..95787e8e199e 100644 --- a/website/source/layouts/docs.erb +++ b/website/source/layouts/docs.erb @@ -234,6 +234,9 @@ > PostgreSQL + > + Custom + From 78b27fa7650552bcec3984add2bf5a4414c61b11 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 3 May 2017 02:13:07 -0700 Subject: [PATCH 135/152] Add API docs --- .../api/secret/databases/cassandra.html.md | 96 +++++ .../source/api/secret/databases/index.html.md | 342 ++++++++++++++++++ .../source/api/secret/databases/mssql.html.md | 60 +++ .../api/secret/databases/mysql-maria.html.md | 60 +++ .../api/secret/databases/postgresql.html.md | 60 +++ .../docs/secrets/databases/custom.html.md | 5 +- website/source/layouts/api.erb | 27 +- 7 files changed, 644 insertions(+), 6 deletions(-) create mode 100644 website/source/api/secret/databases/cassandra.html.md create mode 100644 website/source/api/secret/databases/index.html.md create mode 100644 website/source/api/secret/databases/mssql.html.md create mode 100644 website/source/api/secret/databases/mysql-maria.html.md create mode 100644 website/source/api/secret/databases/postgresql.html.md diff --git a/website/source/api/secret/databases/cassandra.html.md b/website/source/api/secret/databases/cassandra.html.md new file mode 100644 index 000000000000..5e2b5a83603b --- /dev/null +++ b/website/source/api/secret/databases/cassandra.html.md @@ -0,0 +1,96 @@ +--- +layout: "api" +page_title: "Cassandra Database Plugin - HTTP API" +sidebar_current: "docs-http-secret-databases-cassandra-maria" +description: |- + The Cassandra plugin for Vault's Database backend generates database credentials to access Cassandra servers. +--- + +# Cassandra Database Plugin HTTP API + +The Cassandra Database Plugin is one of the supported plugins for the Database +backend. This plugin generates database credentials dynamically based on +configured roles for the Cassandra database. + +## Configure Connection + +In addition to the parameters defined by the [Database +Backend](/api/secret/databases/index.html#configure-connection), this plugin +has a number of parameters to further configure a connection. + +| Method | Path | Produces | +| :------- | :--------------------------- | :--------------------- | +| `POST` | `/database/config/:name` | `204 (empty body)` | + +### Parameters +- `hosts` `(string: )` – Specifies a set of comma-delineated Cassandra + hosts to connect to. + +- `username` `(string: )` – Specifies the username to use for + superuser access. + +- `password` `(string: )` – Specifies the password corresponding to + the given username. + +- `tls` `(bool: true)` – Specifies whether to use TLS when connecting to + Cassandra. + +- `insecure_tls` `(bool: false)` – Specifies whether to skip verification of the + server certificate when using TLS. + +- `pem_bundle` `(string: "")` – Specifies concatenated PEM blocks containing a + certificate and private key; a certificate, private key, and issuing CA + certificate; or just a CA certificate. + +- `pem_json` `(string: "")` – Specifies JSON containing a certificate and + private key; a certificate, private key, and issuing CA certificate; or just a + CA certificate. For convenience format is the same as the output of the + `issue` command from the `pki` backend; see + [the pki documentation](/docs/secrets/pki/index.html). + +- `protocol_version` `(int: 2)` – Specifies the CQL protocol version to use. + +- `connect_timeout` `(string: "5s")` – Specifies the connection timeout to use. + +TLS works as follows: + +- If `tls` is set to true, the connection will use TLS; this happens + automatically if `pem_bundle`, `pem_json`, or `insecure_tls` is set + +- If `insecure_tls` is set to true, the connection will not perform verification + of the server certificate; this also sets `tls` to true + +- If only `issuing_ca` is set in `pem_json`, or the only certificate in + `pem_bundle` is a CA certificate, the given CA certificate will be used for + server certificate verification; otherwise the system CA certificates will be + used + +- If `certificate` and `private_key` are set in `pem_bundle` or `pem_json`, + client auth will be turned on for the connection + +`pem_bundle` should be a PEM-concatenated bundle of a private key + client +certificate, an issuing CA certificate, or both. `pem_json` should contain the +same information; for convenience, the JSON format is the same as that output by +the issue command from the PKI backend. + +### Sample Payload + +```json +{ + "plugin_name": "cassandra-database-plugin", + "allowed_roles": "readonly", + "hosts": "cassandra1.local", + "username": "user", + "password": "pass" +} +``` + +### Sample Request + +``` +$ curl \ + --header "X-Vault-Token: ..." \ + --request POST \ + --data @payload.json \ + https://vault.rocks/v1/cassandra/config/connection +``` diff --git a/website/source/api/secret/databases/index.html.md b/website/source/api/secret/databases/index.html.md new file mode 100644 index 000000000000..9e6015648ae6 --- /dev/null +++ b/website/source/api/secret/databases/index.html.md @@ -0,0 +1,342 @@ +--- +layout: "api" +page_title: "Databases - HTTP API" +sidebar_current: "docs-http-secret-databases" +description: |- + Top page for database secret backend information +--- + +# Database Secret Backend HTTP API + +This is the API documentation for the Vault Database secret backend. For +general information about the usage and operation of the Database backend, +please see the +[Vault Database backend documentation](/docs/secrets/database/index.html). + +This documentation assumes the Database backend is mounted at the +`/database` path in Vault. Since it is possible to mount secret backends at +any location, please update your API calls accordingly. + +## Configure Connection + +This endpoint configures the connection string used to communicate with the +desired database. In addition to the parameters listed here, each Database +plugin has additional, database plugin specifig, parameters for this endpoint. +Please read the HTTP API for the plugin you'd wish to configure to see the full +list of additional parameters. + +| Method | Path | Produces | +| :------- | :--------------------------- | :--------------------- | +| `POST` | `/database/config/:name` | `204 (empty body)` | + +### Parameters +- `name` `(string: )` – Specifies the name for this database + connection. This is specified as part of the URL. + +- `plugin_name` `(string: )` - Specifies the name of the plugin to use + for this connection. + +- `verify_connection` `(bool: true)` – Specifies if the connection is verified + during initial configuration. Defaults to true. + +- `allowed_roles` `(slice: [])` - Array or comma separated string of the roles + allowed to use this connection. Defaults to empty (no roles), if contains a + "*" any role can use this connection. + +### Sample Payload + +```json +{ + "plugin_name": "mysql-database-plugin", + "allowed_roles": "readonly", + "connection_url": "root:mysql@tcp(127.0.0.1:3306)/" +} +``` + +### Sample Request + +``` +$ curl \ + --header "X-Vault-Token: ..." \ + --request POST \ + --data @payload.json \ + https://vault.rocks/v1/database/config/mysql +``` + +## Read Connection + +This endpoint returns the configuration settings for a connection. + +| Method | Path | Produces | +| :------- | :--------------------------- | :--------------------- | +| `GET` | `/database/config/:name` | `200 application/json` | + +### Parameters + +- `name` `(string: )` – Specifies the name of the connection to read. + This is specified as part of the URL. + +### Sample Request + +``` +$ curl \ + --header "X-Vault-Token: ..." \ + --request GET \ + https://vault.rocks/v1/database/config/mysql +``` + +### Sample Response + +```json +{ + "data": { + "allowed_roles": [ + "readonly" + ], + "connection_details": { + "connection_url": "root:mysql@tcp(127.0.0.1:3306)/", + }, + "plugin_name": "mysql-database-plugin" + }, +} +``` + +## Delete Connection + +This endpoint deletes a connection. + +| Method | Path | Produces | +| :------- | :--------------------------- | :--------------------- | +| `DELETE` | `/database/config/:name` | `204 (empty body)` | + +### Parameters + +- `name` `(string: )` – Specifies the name of the connection to delete. + This is specified as part of the URL. + +### Sample Request + +``` +$ curl \ + --header "X-Vault-Token: ..." \ + --request DELETE \ + https://vault.rocks/v1/database/config/mysql +``` + +## Reset Connection + +This endpoint closes a connection and it's underlying plugin and restarts it +with the configuration stored in the barrier. + +| Method | Path | Produces | +| :------- | :--------------------------- | :--------------------- | +| `POST` | `/database/reset/:name` | `204 (empty body)` | + +### Parameters + +- `name` `(string: )` – Specifies the name of the connection to delete. + This is specified as part of the URL. + +### Sample Request + +``` +$ curl \ + --header "X-Vault-Token: ..." \ + --request POST \ + https://vault.rocks/v1/database/reset/mysql +``` + +## Create Role + +This endpoint creates or updates a role definition. + +| Method | Path | Produces | +| :------- | :--------------------------- | :--------------------- | +| `POST` | `/database/roles/:name` | `204 (empty body)` | + +### Parameters + +- `name` `(string: )` – Specifies the name of the role to create. This + is specified as part of the URL. + +- `db_name` `(string: )` - The name of the database connection to use + for this role. + +- `default_ttl` `(string: )` - Specifies the TTL for the lease + associated with this role. + +- `max_ttl` `(string: )` - Specifies the maximum TTL for the lease + associated with this role. + +- `creation_statements` `(string: )` – Specifies the database + statements executed to create and configure a user. Must be a + semicolon-separated string, a base64-encoded semicolon-separated string, a + serialized JSON string array, or a base64-encoded serialized JSON string + array. The '{{name}}', '{{password}}' and '{{expiration}}' values will be + substituted. + +- `revocation_statements` `(string: "")` – Specifies the database statements to + be executed to revoke a user. Must be a semicolon-separated string, a + base64-encoded semicolon-separated string, a serialized JSON string array, or + a base64-encoded serialized JSON string array. The '{{name}}' value will be + substituted. + +- `rollback_statements` `(string: "")` – Specifies the database statements to be + executed rollback a create operation in the event of an error. Not every + plugin type will support this functionality. Must be a semicolon-separated + string, a base64-encoded semicolon-separated string, a serialized JSON string + array, or a base64-encoded serialized JSON string array. The '{{name}}' value + will be substituted. + +- `renew_statements` `(string: "")` – Specifies the database statements to be + executed to renew a user. Not every plugin type will support this + functionality. Must be a semicolon-separated string, a base64-encoded + semicolon-separated string, a serialized JSON string array, or a + base64-encoded serialized JSON string array. The '{{name}}' and + '{{expiration}}` values will be substituted. + + +### Sample Payload + +```json +{ + "db_name": "mysql", + "creation_statements": "CREATE USER '{{name}}'@'%' IDENTIFIED BY '{{password}}';GRANT SELECT ON *.* TO '{{name}}'@'%';", + "default_ttl": "1h", + "max_ttl": "24h" +} +``` + +### Sample Request + +``` +$ curl \ + --header "X-Vault-Token: ..." \ + --request POST \ + --data @payload.json \ + https://vault.rocks/v1/database/roles/my-role +``` + +## Read Role + +This endpoint queries the role definition. + +| Method | Path | Produces | +| :------- | :--------------------------- | :--------------------- | +| `GET` | `/database/roles/:name` | `200 application/json` | + +### Parameters + +- `name` `(string: )` – Specifies the name of the role to read. This + is specified as part of the URL. + +### Sample Request + +``` +$ curl \ + --header "X-Vault-Token: ..." \ + https://vault.rocks/v1/database/roles/my-role +``` + +### Sample Response + +```json +{ + "data": { + "creation_statements": "CREATE ROLE \"{{name}}\" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}'; GRANT SELECT ON ALL TABLES IN SCHEMA public TO \"{{name}}\";", + "db_name": "mysql", + "default_ttl": 3600, + "max_ttl": 86400, + "renew_statements": "", + "revocation_statements": "", + "rollback_statements": "" + }, +} +``` + +## List Roles + +This endpoint returns a list of available roles. Only the role names are +returned, not any values. + +| Method | Path | Produces | +| :------- | :--------------------------- | :--------------------- | +| `LIST` | `/database/roles` | `200 application/json` | + +### Sample Request + +``` +$ curl \ + --header "X-Vault-Token: ..." \ + --request LIST \ + https://vault.rocks/v1/database/roles +``` + +### Sample Response + +```json +{ + "auth": null, + "data": { + "keys": ["dev", "prod"] + }, + "lease_duration": 2764800, + "lease_id": "", + "renewable": false +} +``` + +## Delete Role + +This endpoint deletes the role definition. + +| Method | Path | Produces | +| :------- | :--------------------------- | :--------------------- | +| `DELETE` | `/database/roles/:name` | `204 (empty body)` | + +### Parameters + +- `name` `(string: )` – Specifies the name of the role to delete. This + is specified as part of the URL. + +### Sample Request + +``` +$ curl \ + --header "X-Vault-Token: ..." \ + --request DELETE \ + https://vault.rocks/v1/database/roles/my-role +``` + +## Generate Credentials + +This endpoint generates a new set of dynamic credentials based on the named +role. + +| Method | Path | Produces | +| :------- | :--------------------------- | :--------------------- | +| `GET` | `/database/creds/:name` | `200 application/json` | + +### Parameters + +- `name` `(string: )` – Specifies the name of the role to create + credentials against. This is specified as part of the URL. + +### Sample Request + +``` +$ curl \ + --header "X-Vault-Token: ..." \ + https://vault.rocks/v1/database/creds/my-role +``` + +### Sample Response + +```json +{ + "data": { + "username": "root-1430158508-126", + "password": "132ae3ef-5a64-7499-351e-bfe59f3a2a21" + } +} +``` diff --git a/website/source/api/secret/databases/mssql.html.md b/website/source/api/secret/databases/mssql.html.md new file mode 100644 index 000000000000..09893df455ac --- /dev/null +++ b/website/source/api/secret/databases/mssql.html.md @@ -0,0 +1,60 @@ +--- +layout: "api" +page_title: "MSSQL Database Plugin - HTTP API" +sidebar_current: "docs-http-secret-databases-mssql-maria" +description: |- + The MSSQL plugin for Vault's Database backend generates database credentials to access MSSQL servers. +--- + +# MSSQL Database Plugin HTTP API + +The MSSQL Database Plugin is one of the supported plugins for the Database +backend. This plugin generates database credentials dynamically based on +configured roles for the MSSQL database. + +## Configure Connection + +In addition to the parameters defined by the [Database +Backend](/api/secret/databases/index.html#configure-connection), this plugin +has a number of parameters to further configure a connection. + +| Method | Path | Produces | +| :------- | :--------------------------- | :--------------------- | +| `POST` | `/database/config/:name` | `204 (empty body)` | + +### Parameters +- `connection_url` `(string: )` - Specifies the MSSQL DSN. + +- `max_open_connections` `(int: 2)` - Speficies the name of the plugin to use + for this connection. + +- `max_idle_connections` `(int: 0)` - Specifies the maximum number of idle + connections to the database. A zero uses the value of `max_open_connections` + and a negative value disables idle connections. If larger than + `max_open_connections` it will be reduced to be equal. + +- `max_connection_lifetime` `(string: "0s")` - Specifies the maximum amount of + time a connection may be reused. If <= 0s connections are reused forever. + +### Sample Payload + +```json +{ + "plugin_name": "mssql-database-plugin", + "allowed_roles": "readonly", + "connection_url": "sqlserver://sa:yourStrong(!)Password@localhost:1433", + "max_open_connections": 5, + "max_connection_lifetime": "5s", +} +``` + +### Sample Request + +``` +$ curl \ + --header "X-Vault-Token: ..." \ + --request POST \ + --data @payload.json \ + https://vault.rocks/v1/database/config/mssql +``` + diff --git a/website/source/api/secret/databases/mysql-maria.html.md b/website/source/api/secret/databases/mysql-maria.html.md new file mode 100644 index 000000000000..981506798f69 --- /dev/null +++ b/website/source/api/secret/databases/mysql-maria.html.md @@ -0,0 +1,60 @@ +--- +layout: "api" +page_title: "MySQL/MariaDB Database Plugin - HTTP API" +sidebar_current: "docs-http-secret-databases-mysql-maria" +description: |- + The MySQL/MariaDB plugin for Vault's Database backend generates database credentials to access MySQL and MariaDB servers. +--- + +# MySQL/MariaDB Database Plugin HTTP API + +The MySQL Database Plugin is one of the supported plugins for the Database +backend. This plugin generates database credentials dynamically based on +configured roles for the MySQL database. + +## Configure Connection + +In addition to the parameters defined by the [Database +Backend](/api/secret/databases/index.html#configure-connection), this plugin +has a number of parameters to further configure a connection. + +| Method | Path | Produces | +| :------- | :--------------------------- | :--------------------- | +| `POST` | `/database/config/:name` | `204 (empty body)` | + +### Parameters +- `connection_url` `(string: )` - Specifies the MySQL DSN. + +- `max_open_connections` `(int: 2)` - Speficies the name of the plugin to use + for this connection. + +- `max_idle_connections` `(int: 0)` - Specifies the maximum number of idle + connections to the database. A zero uses the value of `max_open_connections` + and a negative value disables idle connections. If larger than + `max_open_connections` it will be reduced to be equal. + +- `max_connection_lifetime` `(string: "0s")` - Specifies the maximum amount of + time a connection may be reused. If <= 0s connections are reused forever. + +### Sample Payload + +```json +{ + "plugin_name": "mysql-database-plugin", + "allowed_roles": "readonly", + "connection_url": "root:mysql@tcp(127.0.0.1:3306)/" + "max_open_connections": 5, + "max_connection_lifetime": "5s", +} +``` + +### Sample Request + +``` +$ curl \ + --header "X-Vault-Token: ..." \ + --request POST \ + --data @payload.json \ + https://vault.rocks/v1/database/config/mysql +``` + diff --git a/website/source/api/secret/databases/postgresql.html.md b/website/source/api/secret/databases/postgresql.html.md new file mode 100644 index 000000000000..5ff6b8022e37 --- /dev/null +++ b/website/source/api/secret/databases/postgresql.html.md @@ -0,0 +1,60 @@ +--- +layout: "api" +page_title: "PostgreSQL Database Plugin - HTTP API" +sidebar_current: "docs-http-secret-databases-postgresql-maria" +description: |- + The PostgreSQL plugin for Vault's Database backend generates database credentials to access PostgreSQL servers. +--- + +# PostgreSQL Database Plugin HTTP API + +The PostgreSQL Database Plugin is one of the supported plugins for the Database +backend. This plugin generates database credentials dynamically based on +configured roles for the PostgreSQL database. + +## Configure Connection + +In addition to the parameters defined by the [Database +Backend](/api/secret/databases/index.html#configure-connection), this plugin +has a number of parameters to further configure a connection. + +| Method | Path | Produces | +| :------- | :--------------------------- | :--------------------- | +| `POST` | `/database/config/:name` | `204 (empty body)` | + +### Parameters +- `connection_url` `(string: )` - Specifies the PostgreSQL DSN. + +- `max_open_connections` `(int: 2)` - Speficies the name of the plugin to use + for this connection. + +- `max_idle_connections` `(int: 0)` - Specifies the maximum number of idle + connections to the database. A zero uses the value of `max_open_connections` + and a negative value disables idle connections. If larger than + `max_open_connections` it will be reduced to be equal. + +- `max_connection_lifetime` `(string: "0s")` - Specifies the maximum amount of + time a connection may be reused. If <= 0s connections are reused forever. + +### Sample Payload + +```json +{ + "plugin_name": "postgresql-database-plugin", + "allowed_roles": "readonly", + "connection_url": "postgresql://root:root@localhost:5432/postgres", + "max_open_connections": 5, + "max_connection_lifetime": "5s", +} +``` + +### Sample Request + +``` +$ curl \ + --header "X-Vault-Token: ..." \ + --request POST \ + --data @payload.json \ + https://vault.rocks/v1/database/config/postgresql +``` + diff --git a/website/source/docs/secrets/databases/custom.html.md b/website/source/docs/secrets/databases/custom.html.md index 5911e1d809c2..7d21a19c98d1 100644 --- a/website/source/docs/secrets/databases/custom.html.md +++ b/website/source/docs/secrets/databases/custom.html.md @@ -11,7 +11,8 @@ description: |- The Database backend allows new functionality to be added through a plugin interface without needing to modify vault's core code. This allows you write your own code to generate credentials in any database you wish. It also allows -databases that require dynamically linked libraries to be used with vault. +databases that require dynamically linked libraries to be used as plugins while +keeping Vault itself statically linked. ~> **Advanced topic!** Plugin development is a highly advanced topic in Vault, and is not required knowledge for day-to-day usage. @@ -81,7 +82,7 @@ Replacing `MyPlugin` with the actual implementation of your plugin. The second parameter to `Serve` takes in an optional vault `api.TLSConfig` for configuring the plugin to communicate with vault for the initial unwrap call. -This if useful if your vault setup requires client certificate checks. This +This is useful if your vault setup requires client certificate checks. This config wont be used once the plugin unwraps its own TLS cert and key. ## Running your plugin diff --git a/website/source/layouts/api.erb b/website/source/layouts/api.erb index c209937bc0bb..ea8e35624fd4 100644 --- a/website/source/layouts/api.erb +++ b/website/source/layouts/api.erb @@ -21,7 +21,7 @@ AWS > - Cassandra + Cassandra (Deprecated) > Consul @@ -29,6 +29,25 @@ > Cubbyhole + + > + Databases (Beta) + + + > Generic @@ -36,16 +55,16 @@ MongoDB > - MSSQL + MSSQL (Deprecated) > - MySQL + MySQL (Deprecated) > PKI > - PostgreSQL + PostgreSQL (Deprecated) > RabbitMQ From 799cd3c7c7f694f1e5216e34c66d8b8e2cb599d1 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 3 May 2017 10:25:12 -0700 Subject: [PATCH 136/152] Upate links in docs --- website/source/api/secret/databases/index.html.md | 2 +- website/source/docs/secrets/databases/cassandra.html.md | 6 +++--- website/source/docs/secrets/databases/index.html.md | 2 +- website/source/docs/secrets/databases/mssql.html.md | 6 +++--- website/source/docs/secrets/databases/mysql-maria.html.md | 6 +++--- website/source/docs/secrets/databases/postgresql.html.md | 6 +++--- 6 files changed, 14 insertions(+), 14 deletions(-) diff --git a/website/source/api/secret/databases/index.html.md b/website/source/api/secret/databases/index.html.md index 9e6015648ae6..f55998aceb8a 100644 --- a/website/source/api/secret/databases/index.html.md +++ b/website/source/api/secret/databases/index.html.md @@ -11,7 +11,7 @@ description: |- This is the API documentation for the Vault Database secret backend. For general information about the usage and operation of the Database backend, please see the -[Vault Database backend documentation](/docs/secrets/database/index.html). +[Vault Database backend documentation](/docs/secrets/databases/index.html). This documentation assumes the Database backend is mounted at the `/database` path in Vault. Since it is possible to mount secret backends at diff --git a/website/source/docs/secrets/databases/cassandra.html.md b/website/source/docs/secrets/databases/cassandra.html.md index b3d87f7ed238..1d8468ad3a71 100644 --- a/website/source/docs/secrets/databases/cassandra.html.md +++ b/website/source/docs/secrets/databases/cassandra.html.md @@ -14,7 +14,7 @@ The Cassandra Database Plugin is one of the supported plugins for the Database backend. This plugin generates database credentials dynamically based on configured roles for the Cassandra database. -See the [Database Backend](/docs/secret/database/index.html) docs for more +See the [Database Backend](/docs/secrets/databases/index.html) docs for more information about setting up the Database Backend. ## Quick Start @@ -55,8 +55,8 @@ This role can be used to retrieve a new set of credentials by querying the ## API The full list of configurable options can be seen in the [Cassandra database -plugin API](/api/secret/database/cassandra.html) page. +plugin API](/api/secret/databases/cassandra.html) page. Or for more information on the Database secret backend's HTTP API please see the [Database secret -backend API](/api/secret/database/index.html) page. +backend API](/api/secret/databases/index.html) page. diff --git a/website/source/docs/secrets/databases/index.html.md b/website/source/docs/secrets/databases/index.html.md index cf366d9c90ef..c88699c44705 100644 --- a/website/source/docs/secrets/databases/index.html.md +++ b/website/source/docs/secrets/databases/index.html.md @@ -93,5 +93,5 @@ username v-root-e2978cd0- ## API The Database secret backend has a full HTTP API. Please see the [Database secret -backend API](/api/secret/database/index.html) for more details. +backend API](/api/secret/databases/index.html) for more details. diff --git a/website/source/docs/secrets/databases/mssql.html.md b/website/source/docs/secrets/databases/mssql.html.md index 0eefe1764260..fec8924b3fa8 100644 --- a/website/source/docs/secrets/databases/mssql.html.md +++ b/website/source/docs/secrets/databases/mssql.html.md @@ -14,7 +14,7 @@ The MSSQL Database Plugin is one of the supported plugins for the Database backend. This plugin generates database credentials dynamically based on configured roles for the MSSQL database. -See the [Database Backend](/docs/secret/database/index.html) docs for more +See the [Database Backend](/docs/secrets/databases/index.html) docs for more information about setting up the Database Backend. ## Quick Start @@ -53,8 +53,8 @@ This role can now be used to retrieve a new set of credentials by querying the ## API The full list of configurable options can be seen in the [MSSQL database -plugin API](/api/secret/database/mssql.html) page. +plugin API](/api/secret/databases/mssql.html) page. Or for more information on the Database secret backend's HTTP API please see the [Database secret -backend API](/api/secret/database/index.html) page. +backend API](/api/secret/databases/index.html) page. diff --git a/website/source/docs/secrets/databases/mysql-maria.html.md b/website/source/docs/secrets/databases/mysql-maria.html.md index 76ca193fcfaf..c5eea4b7bac9 100644 --- a/website/source/docs/secrets/databases/mysql-maria.html.md +++ b/website/source/docs/secrets/databases/mysql-maria.html.md @@ -14,7 +14,7 @@ The MySQL Database Plugin is one of the supported plugins for the Database backend. This plugin generates database credentials dynamically based on configured roles for the MySQL database. -See the [Database Backend](/docs/secret/database/index.html) docs for more +See the [Database Backend](/docs/secrets/databases/index.html) docs for more information about setting up the Database Backend. ## Quick Start @@ -51,8 +51,8 @@ This role can now be used to retrieve a new set of credentials by querying the ## API The full list of configurable options can be seen in the [MySQL database -plugin API](/api/secret/database/mysql.html) page. +plugin API](/api/secret/databases/mysql-maria.html) page. Or for more information on the Database secret backend's HTTP API please see the [Database secret -backend API](/api/secret/database/index.html) page. +backend API](/api/secret/databases/index.html) page. diff --git a/website/source/docs/secrets/databases/postgresql.html.md b/website/source/docs/secrets/databases/postgresql.html.md index 81716132f52b..72601e34fc4c 100644 --- a/website/source/docs/secrets/databases/postgresql.html.md +++ b/website/source/docs/secrets/databases/postgresql.html.md @@ -14,7 +14,7 @@ The PostgreSQL Database Plugin is one of the supported plugins for the Database backend. This plugin generates database credentials dynamically based on configured roles for the PostgreSQL database. -See the [Database Backend](/docs/secret/database/index.html) docs for more +See the [Database Backend](/docs/secrets/databases/index.html) docs for more information about setting up the Database Backend. ## Quick Start @@ -53,8 +53,8 @@ This role can be used to retrieve a new set of credentials by querying the ## API The full list of configurable options can be seen in the [PostgreSQL database -plugin API](/api/secret/database/postgresql.html) page. +plugin API](/api/secret/databases/postgresql.html) page. Or for more information on the Database secret backend's HTTP API please see the [Database secret -backend API](/api/secret/database/index.html) page. +backend API](/api/secret/databases/index.html) page. From 311acb34a5a1e50ad81392a51827e418df7dfe25 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 3 May 2017 11:43:24 -0700 Subject: [PATCH 137/152] Add the plugins catalog API docs --- .../source/api/system/plugins-catalog.html.md | 155 ++++++++++++++++++ website/source/layouts/api.erb | 3 + 2 files changed, 158 insertions(+) create mode 100644 website/source/api/system/plugins-catalog.html.md diff --git a/website/source/api/system/plugins-catalog.html.md b/website/source/api/system/plugins-catalog.html.md new file mode 100644 index 000000000000..b955261949d1 --- /dev/null +++ b/website/source/api/system/plugins-catalog.html.md @@ -0,0 +1,155 @@ +--- +layout: "api" +page_title: "/sys/plugins/catalog - HTTP API" +sidebar_current: "docs-http-system-plugins-catalog" +description: |- + The `/sys/plugins/catalog` endpoint is used to manage plugins. +--- + +# `/sys/plugins/catalog` + +The `/sys/plugins/catalog` endpoint is used to list, register, update, and +remove plugins in Vault's catalog. Plugins must be registered before use, and +once registered backends can use the plugin by querying the catalog. + +## List Plugins + +This endpoint lists the plugins in the catalog. + +| Method | Path | Produces | +| :------- | :--------------------------- | :--------------------- | +| `LIST` | `/sys/plugins/catalog/` | `200 application/json` | + +### Sample Request + +``` +$ curl \ + --header "X-Vault-Token: ..." \ + --request LIST + https://vault.rocks/v1/sys/plugins/catalog +``` + +### Sample Response + +```javascript +{ + "data": { + "keys": [ + "cassandra-database-plugin", + "mssql-database-plugin", + "mysql-database-plugin", + "postgresql-database-plugin" + ] + } +} +``` + +## Register Plugin + +This endpoint registers a new plugin, or updates an existing one with the +supplied name. + +- **`sudo` required** – This endpoint requires `sudo` capability in addition to + any path-specific capabilities. + +| Method | Path | Produces | +| :------- | :--------------------------- | :--------------------- | +| `PUT` | `/sys/plugins/catalog/:name` | `204 (empty body)` | + +### Parameters + +- `name` `(string: )` – Specifies the name for this plugin. The name + is what is used to look up plugins in the catalog. This is part of the request + URL. + +- `sha_256` `(string: )` – This is the SHA256 sum of the plugin's + binary. Before a plugin is run it's SHA will be checked against this value, if + they do not match the plugin can not be run. + +- `command` `(string: )` – Specifies the command used to execute the + plugin. This is relative to the plugin directory. e.g. `"myplugin + --my_flag=1"` + +### Sample Payload + +```json +{ + "sha_256": "d130b9a0fbfddef9709d8ff92e5e6053ccd246b78632fc03b8548457026961e9", + "command": "mysql-database-plugin" +} +``` + +### Sample Request + +``` +$ curl \ + --header "X-Vault-Token: ..." \ + --request PUT \ + --data @payload.json \ + https://vault.rocks/v1/sys/plugins/catalog/example-plugin +``` + +## Read Plugin + +This endpoint returns the configuration data for the plugin with the given name. + +- **`sudo` required** – This endpoint requires `sudo` capability in addition to + any path-specific capabilities. + +| Method | Path | Produces | +| :------- | :--------------------------- | :--------------------- | +| `GET` | `/sys/plugins/catalog/:name` | `200 application/json` | + +### Parameters + +- `name` `(string: )` – Specifies the name of the plugin to retrieve. + This is part of the request URL. + +### Sample Request + +``` +$ curl \ + --header "X-Vault-Token: ..." \ + --request GET \ + https://vault.rocks/v1/sys/plugins/catalog/example-plugin +``` + +### Sample Response + +```javascript +{ + "data": { + "plugin": { + "args": [], + "builtin": false, + "command": "/tmp/vault-plugins/mysql-database-plugin", + "name": "example-plugin", + "sha256": "0TC5oPv93vlwnY/5Ll5gU8zSRreGMvwDuFSEVwJpYek=" + } + } +} +``` +## Remove Plugin from Catalog + +This endpoint removes the plugin with the given name. + +- **`sudo` required** – This endpoint requires `sudo` capability in addition to + any path-specific capabilities. + +| Method | Path | Produces | +| :------- | :--------------------------- | :--------------------- | +| `DELETE` | `/sys/plugins/catalog/:name` | `204 (empty body)` | + +### Parameters + +- `name` `(string: )` – Specifies the name of the plugin to delete. + This is part of the request URL. + +### Sample Request + +``` +$ curl \ + --header "X-Vault-Token: ..." \ + --request DELETE \ + https://vault.rocks/v1/sys/plugins/catalog/example-plugin +``` diff --git a/website/source/layouts/api.erb b/website/source/layouts/api.erb index ea8e35624fd4..c6e92d026eb3 100644 --- a/website/source/layouts/api.erb +++ b/website/source/layouts/api.erb @@ -120,6 +120,9 @@ > /sys/mounts + > + /sys/plugins/catalog + > /sys/policy From f424a9ad795df6ba758c9c28f522d4f93180230a Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 3 May 2017 13:01:05 -0700 Subject: [PATCH 138/152] Use log to output errors instead of fmt --- plugins/database/cassandra/cassandra-database-plugin/main.go | 4 ++-- plugins/database/mssql/mssql-database-plugin/main.go | 4 ++-- plugins/database/mysql/mysql-database-plugin/main.go | 4 ++-- .../database/postgresql/postgresql-database-plugin/main.go | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/plugins/database/cassandra/cassandra-database-plugin/main.go b/plugins/database/cassandra/cassandra-database-plugin/main.go index bb3f44142195..c70997897e6d 100644 --- a/plugins/database/cassandra/cassandra-database-plugin/main.go +++ b/plugins/database/cassandra/cassandra-database-plugin/main.go @@ -1,7 +1,7 @@ package main import ( - "fmt" + "log" "os" "github.com/hashicorp/vault/helper/pluginutil" @@ -15,7 +15,7 @@ func main() { err := cassandra.Run(apiClientMeta.GetTLSConfig()) if err != nil { - fmt.Println(err) + log.Println(err) os.Exit(1) } } diff --git a/plugins/database/mssql/mssql-database-plugin/main.go b/plugins/database/mssql/mssql-database-plugin/main.go index d52fd13db0db..5f05c5dffa0a 100644 --- a/plugins/database/mssql/mssql-database-plugin/main.go +++ b/plugins/database/mssql/mssql-database-plugin/main.go @@ -1,7 +1,7 @@ package main import ( - "fmt" + "log" "os" "github.com/hashicorp/vault/helper/pluginutil" @@ -15,7 +15,7 @@ func main() { err := mssql.Run(apiClientMeta.GetTLSConfig()) if err != nil { - fmt.Println(err) + log.Println(err) os.Exit(1) } } diff --git a/plugins/database/mysql/mysql-database-plugin/main.go b/plugins/database/mysql/mysql-database-plugin/main.go index a9389f50420d..249e5afeef2f 100644 --- a/plugins/database/mysql/mysql-database-plugin/main.go +++ b/plugins/database/mysql/mysql-database-plugin/main.go @@ -1,7 +1,7 @@ package main import ( - "fmt" + "log" "os" "github.com/hashicorp/vault/helper/pluginutil" @@ -15,7 +15,7 @@ func main() { err := mysql.Run(apiClientMeta.GetTLSConfig()) if err != nil { - fmt.Println(err) + log.Println(err) os.Exit(1) } } diff --git a/plugins/database/postgresql/postgresql-database-plugin/main.go b/plugins/database/postgresql/postgresql-database-plugin/main.go index e6acb0584748..ac3cf95a7596 100644 --- a/plugins/database/postgresql/postgresql-database-plugin/main.go +++ b/plugins/database/postgresql/postgresql-database-plugin/main.go @@ -1,7 +1,7 @@ package main import ( - "fmt" + "log" "os" "github.com/hashicorp/vault/helper/pluginutil" @@ -15,7 +15,7 @@ func main() { err := postgresql.Run(apiClientMeta.GetTLSConfig()) if err != nil { - fmt.Println(err) + log.Println(err) os.Exit(1) } } From c381b0093c604d5b31fec0875357ed3623fd1059 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 3 May 2017 13:11:30 -0700 Subject: [PATCH 139/152] Use ParseDurationSecond to parse the timeouts in connutil --- plugins/helper/database/connutil/cassandra.go | 42 +++++++++++-------- plugins/helper/database/connutil/sql.go | 15 +++---- 2 files changed, 32 insertions(+), 25 deletions(-) diff --git a/plugins/helper/database/connutil/cassandra.go b/plugins/helper/database/connutil/cassandra.go index 1babc3cbde8b..27fb2519587b 100644 --- a/plugins/helper/database/connutil/cassandra.go +++ b/plugins/helper/database/connutil/cassandra.go @@ -11,28 +11,30 @@ import ( "github.com/gocql/gocql" "github.com/hashicorp/vault/helper/certutil" + "github.com/hashicorp/vault/helper/parseutil" "github.com/hashicorp/vault/helper/tlsutil" ) // CassandraConnectionProducer implements ConnectionProducer and provides an // interface for cassandra databases to make connections. type CassandraConnectionProducer struct { - Hosts string `json:"hosts" structs:"hosts" mapstructure:"hosts"` - Username string `json:"username" structs:"username" mapstructure:"username"` - Password string `json:"password" structs:"password" mapstructure:"password"` - TLS bool `json:"tls" structs:"tls" mapstructure:"tls"` - InsecureTLS bool `json:"insecure_tls" structs:"insecure_tls" mapstructure:"insecure_tls"` - Certificate string `json:"certificate" structs:"certificate" mapstructure:"certificate"` - PrivateKey string `json:"private_key" structs:"private_key" mapstructure:"private_key"` - IssuingCA string `json:"issuing_ca" structs:"issuing_ca" mapstructure:"issuing_ca"` - ProtocolVersion int `json:"protocol_version" structs:"protocol_version" mapstructure:"protocol_version"` - ConnectTimeout int `json:"connect_timeout" structs:"connect_timeout" mapstructure:"connect_timeout"` - TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"` - Consistency string `json:"consistency" structs:"consistency" mapstructure:"consistency"` - - Initialized bool - Type string - session *gocql.Session + Hosts string `json:"hosts" structs:"hosts" mapstructure:"hosts"` + Username string `json:"username" structs:"username" mapstructure:"username"` + Password string `json:"password" structs:"password" mapstructure:"password"` + TLS bool `json:"tls" structs:"tls" mapstructure:"tls"` + InsecureTLS bool `json:"insecure_tls" structs:"insecure_tls" mapstructure:"insecure_tls"` + Certificate string `json:"certificate" structs:"certificate" mapstructure:"certificate"` + PrivateKey string `json:"private_key" structs:"private_key" mapstructure:"private_key"` + IssuingCA string `json:"issuing_ca" structs:"issuing_ca" mapstructure:"issuing_ca"` + ProtocolVersion int `json:"protocol_version" structs:"protocol_version" mapstructure:"protocol_version"` + ConnectTimeoutRaw interface{} `json:"connect_timeout" structs:"connect_timeout" mapstructure:"connect_timeout"` + TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"` + Consistency string `json:"consistency" structs:"consistency" mapstructure:"consistency"` + + connectTimeout time.Duration + Initialized bool + Type string + session *gocql.Session sync.Mutex } @@ -46,6 +48,11 @@ func (c *CassandraConnectionProducer) Initialize(conf map[string]interface{}, ve } c.Initialized = true + c.connectTimeout, err = parseutil.ParseDurationSecond(c.ConnectTimeoutRaw) + if err != nil { + return fmt.Errorf("invalid connect_timeout: %s", err) + } + if verifyConnection { if _, err := c.Connection(); err != nil { return fmt.Errorf("error Initalizing Connection: %s", err) @@ -101,8 +108,7 @@ func (c *CassandraConnectionProducer) createSession() (*gocql.Session, error) { clusterConfig.ProtoVersion = 2 } - clusterConfig.Timeout = time.Duration(c.ConnectTimeout) * time.Second - + clusterConfig.Timeout = c.connectTimeout if c.TLS { var tlsConfig *tls.Config if len(c.Certificate) > 0 || len(c.IssuingCA) > 0 { diff --git a/plugins/helper/database/connutil/sql.go b/plugins/helper/database/connutil/sql.go index 0bfc5f9f684f..4a636856037d 100644 --- a/plugins/helper/database/connutil/sql.go +++ b/plugins/helper/database/connutil/sql.go @@ -10,19 +10,20 @@ import ( // Import sql drivers _ "github.com/denisenkom/go-mssqldb" _ "github.com/go-sql-driver/mysql" + "github.com/hashicorp/vault/helper/parseutil" _ "github.com/lib/pq" "github.com/mitchellh/mapstructure" ) // SQLConnectionProducer implements ConnectionProducer and provides a generic producer for most sql databases type SQLConnectionProducer struct { - ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"` - MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"` - MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"` - MaxConnectionLifetimeRaw string `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"` + ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"` + MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"` + MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"` + MaxConnectionLifetimeRaw interface{} `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"` Type string - MaxConnectionLifetime time.Duration + maxConnectionLifetime time.Duration Initialized bool db *sql.DB sync.Mutex @@ -51,7 +52,7 @@ func (c *SQLConnectionProducer) Initialize(conf map[string]interface{}, verifyCo c.MaxConnectionLifetimeRaw = "0s" } - c.MaxConnectionLifetime, err = time.ParseDuration(c.MaxConnectionLifetimeRaw) + c.maxConnectionLifetime, err = parseutil.ParseDurationSecond(c.MaxConnectionLifetimeRaw) if err != nil { return fmt.Errorf("invalid max_connection_lifetime: %s", err) } @@ -110,7 +111,7 @@ func (c *SQLConnectionProducer) Connection() (interface{}, error) { // since the request rate shouldn't be high. c.db.SetMaxOpenConns(c.MaxOpenConnections) c.db.SetMaxIdleConns(c.MaxIdleConnections) - c.db.SetConnMaxLifetime(c.MaxConnectionLifetime) + c.db.SetConnMaxLifetime(c.maxConnectionLifetime) return c.db, nil } From 657826d2743580d3d719f15bfb5d86dd17b27c2a Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 3 May 2017 13:33:56 -0700 Subject: [PATCH 140/152] Add the other mysql plugin types with the correct username length settings --- helper/builtinplugins/builtin.go | 8 +++++- plugins/database/mysql/mysql.go | 48 ++++++++++++++++++++------------ 2 files changed, 37 insertions(+), 19 deletions(-) diff --git a/helper/builtinplugins/builtin.go b/helper/builtinplugins/builtin.go index 3dec8588be9d..8e6ed22ef674 100644 --- a/helper/builtinplugins/builtin.go +++ b/helper/builtinplugins/builtin.go @@ -10,7 +10,13 @@ import ( type BuiltinFactory func() (interface{}, error) var plugins map[string]BuiltinFactory = map[string]BuiltinFactory{ - "mysql-database-plugin": mysql.New, + // These four plugins all use the same mysql implementation but with + // different username settings passed by the constructor. + "mysql-database-plugin": mysql.New(mysql.DisplayNameLen, mysql.UsernameLen), + "aurora-database-plugin": mysql.New(mysql.LegacyDisplayNameLen, mysql.LegacyUsernameLen), + "rds-database-plugin": mysql.New(mysql.LegacyDisplayNameLen, mysql.LegacyUsernameLen), + "mysql-legacy-database-plugin": mysql.New(mysql.LegacyDisplayNameLen, mysql.LegacyUsernameLen), + "postgresql-database-plugin": postgresql.New, "mssql-database-plugin": mssql.New, "cassandra-database-plugin": cassandra.New, diff --git a/plugins/database/mysql/mysql.go b/plugins/database/mysql/mysql.go index 7a44d7341f1a..b875af520d56 100644 --- a/plugins/database/mysql/mysql.go +++ b/plugins/database/mysql/mysql.go @@ -14,11 +14,20 @@ import ( "github.com/hashicorp/vault/plugins/helper/database/dbutil" ) -const defaultMysqlRevocationStmts = ` - REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'%'; - DROP USER '{{name}}'@'%' -` -const mySQLTypeName = "mysql" +const ( + defaultMysqlRevocationStmts = ` + REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'%'; + DROP USER '{{name}}'@'%' + ` + mySQLTypeName = "mysql" +) + +var ( + DisplayNameLen int = 10 + LegacyDisplayNameLen int = 4 + UsernameLen int = 32 + LegacyUsernameLen int = 16 +) type MySQL struct { connutil.ConnectionProducer @@ -26,26 +35,29 @@ type MySQL struct { } // New implements builtinplugins.BuiltinFactory -func New() (interface{}, error) { - connProducer := &connutil.SQLConnectionProducer{} - connProducer.Type = mySQLTypeName +func New(displayLen, usernameLen int) func() (interface{}, error) { + return func() (interface{}, error) { + connProducer := &connutil.SQLConnectionProducer{} + connProducer.Type = mySQLTypeName + + credsProducer := &credsutil.SQLCredentialsProducer{ + DisplayNameLen: displayLen, + UsernameLen: usernameLen, + } - credsProducer := &credsutil.SQLCredentialsProducer{ - DisplayNameLen: 4, - UsernameLen: 16, - } + dbType := &MySQL{ + ConnectionProducer: connProducer, + CredentialsProducer: credsProducer, + } - dbType := &MySQL{ - ConnectionProducer: connProducer, - CredentialsProducer: credsProducer, + return dbType, nil } - - return dbType, nil } // Run instantiates a MySQL object, and runs the RPC server for the plugin func Run(apiTLSConfig *api.TLSConfig) error { - dbType, err := New() + f := New(DisplayNameLen, UsernameLen) + dbType, err := f() if err != nil { return err } From 5b8ce92e1276eb36f35712b4d33e247e1ed74ca6 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 3 May 2017 13:36:16 -0700 Subject: [PATCH 141/152] Fix mysql plugin tests --- plugins/database/mysql/mysql_test.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/plugins/database/mysql/mysql_test.go b/plugins/database/mysql/mysql_test.go index c86f9c2f6b1d..72dbd81560ad 100644 --- a/plugins/database/mysql/mysql_test.go +++ b/plugins/database/mysql/mysql_test.go @@ -66,7 +66,8 @@ func TestMySQL_Initialize(t *testing.T) { "connection_url": connURL, } - dbRaw, _ := New() + f := New(DisplayNameLen, UsernameLen) + dbRaw, _ := f() db := dbRaw.(*MySQL) connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer) @@ -93,7 +94,8 @@ func TestMySQL_CreateUser(t *testing.T) { "connection_url": connURL, } - dbRaw, _ := New() + f := New(DisplayNameLen, UsernameLen) + dbRaw, _ := f() db := dbRaw.(*MySQL) err := db.Initialize(connectionDetails, true) @@ -129,7 +131,8 @@ func TestMySQL_RevokeUser(t *testing.T) { "connection_url": connURL, } - dbRaw, _ := New() + f := New(DisplayNameLen, UsernameLen) + dbRaw, _ := f() db := dbRaw.(*MySQL) err := db.Initialize(connectionDetails, true) From 3ca266b4751f8ae30604faf4355560d57730bb67 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 3 May 2017 13:45:27 -0700 Subject: [PATCH 142/152] Fix parsing the connection duration when it's nil --- plugins/helper/database/connutil/cassandra.go | 3 +++ plugins/helper/database/connutil/sql.go | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/plugins/helper/database/connutil/cassandra.go b/plugins/helper/database/connutil/cassandra.go index 27fb2519587b..958bef201077 100644 --- a/plugins/helper/database/connutil/cassandra.go +++ b/plugins/helper/database/connutil/cassandra.go @@ -48,6 +48,9 @@ func (c *CassandraConnectionProducer) Initialize(conf map[string]interface{}, ve } c.Initialized = true + if c.ConnectTimeoutRaw == nil { + c.ConnectTimeoutRaw = "0s" + } c.connectTimeout, err = parseutil.ParseDurationSecond(c.ConnectTimeoutRaw) if err != nil { return fmt.Errorf("invalid connect_timeout: %s", err) diff --git a/plugins/helper/database/connutil/sql.go b/plugins/helper/database/connutil/sql.go index 4a636856037d..5067e10d7cc6 100644 --- a/plugins/helper/database/connutil/sql.go +++ b/plugins/helper/database/connutil/sql.go @@ -48,7 +48,7 @@ func (c *SQLConnectionProducer) Initialize(conf map[string]interface{}, verifyCo if c.MaxIdleConnections > c.MaxOpenConnections { c.MaxIdleConnections = c.MaxOpenConnections } - if c.MaxConnectionLifetimeRaw == "" { + if c.MaxConnectionLifetimeRaw == nil { c.MaxConnectionLifetimeRaw = "0s" } From 3fcf1ad44262e99add927b6ffebacc3927e782f1 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 3 May 2017 15:36:49 -0700 Subject: [PATCH 143/152] Fix the TLS functionality in cassandra plugin --- plugins/helper/database/connutil/cassandra.go | 70 +++++++++++++++---- plugins/helper/database/connutil/sql.go | 4 ++ 2 files changed, 61 insertions(+), 13 deletions(-) diff --git a/plugins/helper/database/connutil/cassandra.go b/plugins/helper/database/connutil/cassandra.go index 958bef201077..869c39e3b6dc 100644 --- a/plugins/helper/database/connutil/cassandra.go +++ b/plugins/helper/database/connutil/cassandra.go @@ -23,18 +23,21 @@ type CassandraConnectionProducer struct { Password string `json:"password" structs:"password" mapstructure:"password"` TLS bool `json:"tls" structs:"tls" mapstructure:"tls"` InsecureTLS bool `json:"insecure_tls" structs:"insecure_tls" mapstructure:"insecure_tls"` - Certificate string `json:"certificate" structs:"certificate" mapstructure:"certificate"` - PrivateKey string `json:"private_key" structs:"private_key" mapstructure:"private_key"` - IssuingCA string `json:"issuing_ca" structs:"issuing_ca" mapstructure:"issuing_ca"` ProtocolVersion int `json:"protocol_version" structs:"protocol_version" mapstructure:"protocol_version"` ConnectTimeoutRaw interface{} `json:"connect_timeout" structs:"connect_timeout" mapstructure:"connect_timeout"` TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"` Consistency string `json:"consistency" structs:"consistency" mapstructure:"consistency"` + PemBundle string `json:"pem_bundle" structs:"pem_bundle" mapstructure:"pem_bundle"` + PemJSON string `json:"pem_json" structs:"pem_json" mapstructure:"pem_json"` connectTimeout time.Duration - Initialized bool - Type string - session *gocql.Session + certificate string + privateKey string + issuingCA string + + Initialized bool + Type string + session *gocql.Session sync.Mutex } @@ -56,6 +59,47 @@ func (c *CassandraConnectionProducer) Initialize(conf map[string]interface{}, ve return fmt.Errorf("invalid connect_timeout: %s", err) } + switch { + case len(c.Hosts) == 0: + return fmt.Errorf("hosts cannot be empty") + case len(c.Username) == 0: + return fmt.Errorf("username cannot be empty") + case len(c.Password) == 0: + return fmt.Errorf("password cannot be empty") + } + + var certBundle *certutil.CertBundle + var parsedCertBundle *certutil.ParsedCertBundle + switch { + case len(c.PemJSON) != 0: + parsedCertBundle, err = certutil.ParsePKIJSON([]byte(c.PemJSON)) + if err != nil { + return fmt.Errorf("could not parse given JSON; it must be in the format of the output of the PKI backend certificate issuing command: %s", err) + } + certBundle, err = parsedCertBundle.ToCertBundle() + if err != nil { + return fmt.Errorf("Error marshaling PEM information: %s", err) + } + c.certificate = certBundle.Certificate + c.privateKey = certBundle.PrivateKey + c.issuingCA = certBundle.IssuingCA + c.TLS = true + + case len(c.PemBundle) != 0: + parsedCertBundle, err = certutil.ParsePEMBundle(c.PemBundle) + if err != nil { + return fmt.Errorf("Error parsing the given PEM information: %s", err) + } + certBundle, err = parsedCertBundle.ToCertBundle() + if err != nil { + return fmt.Errorf("Error marshaling PEM information: %s", err) + } + c.certificate = certBundle.Certificate + c.privateKey = certBundle.PrivateKey + c.issuingCA = certBundle.IssuingCA + c.TLS = true + } + if verifyConnection { if _, err := c.Connection(); err != nil { return fmt.Errorf("error Initalizing Connection: %s", err) @@ -114,18 +158,18 @@ func (c *CassandraConnectionProducer) createSession() (*gocql.Session, error) { clusterConfig.Timeout = c.connectTimeout if c.TLS { var tlsConfig *tls.Config - if len(c.Certificate) > 0 || len(c.IssuingCA) > 0 { - if len(c.Certificate) > 0 && len(c.PrivateKey) == 0 { + if len(c.certificate) > 0 || len(c.issuingCA) > 0 { + if len(c.certificate) > 0 && len(c.privateKey) == 0 { return nil, fmt.Errorf("found certificate for TLS authentication but no private key") } certBundle := &certutil.CertBundle{} - if len(c.Certificate) > 0 { - certBundle.Certificate = c.Certificate - certBundle.PrivateKey = c.PrivateKey + if len(c.certificate) > 0 { + certBundle.Certificate = c.certificate + certBundle.PrivateKey = c.privateKey } - if len(c.IssuingCA) > 0 { - certBundle.IssuingCA = c.IssuingCA + if len(c.issuingCA) > 0 { + certBundle.IssuingCA = c.issuingCA } parsedCertBundle, err := certBundle.ToParsedCertBundle() diff --git a/plugins/helper/database/connutil/sql.go b/plugins/helper/database/connutil/sql.go index 5067e10d7cc6..04269798f190 100644 --- a/plugins/helper/database/connutil/sql.go +++ b/plugins/helper/database/connutil/sql.go @@ -38,6 +38,10 @@ func (c *SQLConnectionProducer) Initialize(conf map[string]interface{}, verifyCo return err } + if len(c.ConnectionURL) == 0 { + return fmt.Errorf("connection_url cannot be empty") + } + if c.MaxOpenConnections == 0 { c.MaxOpenConnections = 2 } From a3619c452191ce734847da6338ee1f7ce9bd7966 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 3 May 2017 16:34:09 -0700 Subject: [PATCH 144/152] Update databse backend tests to use the APIClientMeta for the plugin conns --- builtin/logical/database/backend_test.go | 23 ++++++++++++++++--- .../logical/database/dbplugin/plugin_test.go | 9 ++++++-- 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/builtin/logical/database/backend_test.go b/builtin/logical/database/backend_test.go index 70ec22ee2fed..27c20d33258e 100644 --- a/builtin/logical/database/backend_test.go +++ b/builtin/logical/database/backend_test.go @@ -3,6 +3,7 @@ package database import ( "database/sql" "fmt" + "io/ioutil" "log" stdhttp "net/http" "os" @@ -10,7 +11,6 @@ import ( "sync" "testing" - "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/http" @@ -109,11 +109,28 @@ func TestBackend_PluginMain(t *testing.T) { return } - err := postgresql.Run(&api.TLSConfig{Insecure: true}) + content := []byte(vault.TestClusterCACert) + tmpfile, err := ioutil.TempFile("", "example") if err != nil { t.Fatal(err) } - t.Fatal("We shouldn't get here") + + defer os.Remove(tmpfile.Name()) // clean up + + if _, err := tmpfile.Write(content); err != nil { + t.Fatal(err) + } + if err := tmpfile.Close(); err != nil { + t.Fatal(err) + } + + args := []string{"--ca-cert=" + tmpfile.Name()} + + apiClientMeta := &pluginutil.APIClientMeta{} + flags := apiClientMeta.FlagSet() + flags.Parse(args) + + postgresql.Run(apiClientMeta.GetTLSConfig()) } func TestBackend_config_connection(t *testing.T) { diff --git a/builtin/logical/database/dbplugin/plugin_test.go b/builtin/logical/database/dbplugin/plugin_test.go index c38d85ed3969..c95e119e0ca0 100644 --- a/builtin/logical/database/dbplugin/plugin_test.go +++ b/builtin/logical/database/dbplugin/plugin_test.go @@ -7,7 +7,6 @@ import ( "testing" "time" - "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/http" @@ -107,7 +106,13 @@ func TestPlugin_Main(t *testing.T) { users: make(map[string][]string), } - plugins.Serve(plugin, &api.TLSConfig{Insecure: true}) + args := []string{"--tls-skip-verify=true"} + + apiClientMeta := &pluginutil.APIClientMeta{} + flags := apiClientMeta.FlagSet() + flags.Parse(args) + + plugins.Serve(plugin, apiClientMeta.GetTLSConfig()) } func TestPlugin_Initialize(t *testing.T) { From 2af2b855f572c55bd8940fe33d6d259813cbc915 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 3 May 2017 17:37:34 -0700 Subject: [PATCH 145/152] Feedback from PR --- builtin/logical/database/backend.go | 8 ++--- .../database/path_config_connection.go | 4 +-- helper/pluginutil/runner.go | 6 ++-- helper/pluginutil/tls.go | 30 +++++++++---------- 4 files changed, 23 insertions(+), 25 deletions(-) diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index da8c8384acd2..3d1502805128 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -54,7 +54,7 @@ type databaseBackend struct { sync.RWMutex } -// resetAllDBs closes all connections from all database types +// closeAllDBs closes all connections from all database types func (b *databaseBackend) closeAllDBs() { b.Lock() defer b.Unlock() @@ -120,8 +120,8 @@ func (b *databaseBackend) DatabaseConfig(s logical.Storage, name string) (*Datab return &config, nil } -func (b *databaseBackend) Role(s logical.Storage, n string) (*roleEntry, error) { - entry, err := s.Get("role/" + n) +func (b *databaseBackend) Role(s logical.Storage, roleName string) (*roleEntry, error) { + entry, err := s.Get("role/" + roleName) if err != nil { return nil, err } @@ -170,7 +170,7 @@ func (b *databaseBackend) closeIfShutdown(name string, err error) { const backendHelp = ` The database backend supports using many different databases as secret backends, including but not limited to: -cassandra, msslq, mysql, postgres +cassandra, mssql, mysql, postgres After mounting this backend, configure it using the endpoints within the "database/config/" path. diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index f3767428586e..e84212bb89a8 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -115,7 +115,7 @@ func pathConfigurePluginConnection(b *databaseBackend) *framework.Path { } } -// pathConnectionRead reads out the connection configuration +// connectionReadHandler reads out the connection configuration func (b *databaseBackend) connectionReadHandler() framework.OperationFunc { return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { name := data.Get("name").(string) @@ -248,7 +248,7 @@ database. This path runs the provided plugin name and passes the configured connection details to the plugin. See the documentation for the plugin specified for a full list of accepted connection details. -In addition to the database specific connection details, this endpoing also +In addition to the database specific connection details, this endpoint also accepts: * "plugin_name" (required) - The name of a builtin or previously registered diff --git a/helper/pluginutil/runner.go b/helper/pluginutil/runner.go index 9dbe5c51bb7b..4b25ba16bb09 100644 --- a/helper/pluginutil/runner.go +++ b/helper/pluginutil/runner.go @@ -47,20 +47,20 @@ type PluginRunner struct { // plugin. func (r *PluginRunner) Run(wrapper RunnerUtil, pluginMap map[string]plugin.Plugin, hs plugin.HandshakeConfig, env []string) (*plugin.Client, error) { // Get a CA TLS Certificate - certBytes, key, err := GenerateCert() + certBytes, key, err := generateCert() if err != nil { return nil, err } // Use CA to sign a client cert and return a configured TLS config - clientTLSConfig, err := CreateClientTLSConfig(certBytes, key) + clientTLSConfig, err := createClientTLSConfig(certBytes, key) if err != nil { return nil, err } // Use CA to sign a server cert and wrap the values in a response wrapped // token. - wrapToken, err := WrapServerConfig(wrapper, certBytes, key) + wrapToken, err := wrapServerConfig(wrapper, certBytes, key) if err != nil { return nil, err } diff --git a/helper/pluginutil/tls.go b/helper/pluginutil/tls.go index b355079d6e24..1a7fbe78308f 100644 --- a/helper/pluginutil/tls.go +++ b/helper/pluginutil/tls.go @@ -10,17 +10,15 @@ import ( "encoding/base64" "errors" "fmt" - "math/big" - mathrand "math/rand" "net/url" "os" - "strings" "time" "github.com/SermoDigital/jose/jws" "github.com/hashicorp/errwrap" uuid "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/helper/certutil" ) var ( @@ -29,9 +27,9 @@ var ( PluginUnwrapTokenEnv = "VAULT_UNWRAP_TOKEN" ) -// generateSignedCert is used internally to create certificates for the plugin -// client and server. These certs are signed by the given CA Cert and Key. -func GenerateCert() ([]byte, *ecdsa.PrivateKey, error) { +// generateCert is used internally to create certificates for the plugin +// client and server. +func generateCert() ([]byte, *ecdsa.PrivateKey, error) { key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) if err != nil { return nil, nil, err @@ -42,6 +40,11 @@ func GenerateCert() ([]byte, *ecdsa.PrivateKey, error) { return nil, nil, err } + sn, err := certutil.GenerateSerialNumber() + if err != nil { + return nil, nil, err + } + template := &x509.Certificate{ Subject: pkix.Name{ CommonName: host, @@ -52,7 +55,7 @@ func GenerateCert() ([]byte, *ecdsa.PrivateKey, error) { x509.ExtKeyUsageServerAuth, }, KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement, - SerialNumber: big.NewInt(mathrand.Int63()), + SerialNumber: sn, NotBefore: time.Now().Add(-30 * time.Second), NotAfter: time.Now().Add(262980 * time.Hour), IsCA: true, @@ -66,9 +69,9 @@ func GenerateCert() ([]byte, *ecdsa.PrivateKey, error) { return certBytes, key, nil } -// CreateClientTLSConfig creates a signed certificate and returns a configured +// createClientTLSConfig creates a signed certificate and returns a configured // TLS config. -func CreateClientTLSConfig(certBytes []byte, key *ecdsa.PrivateKey) (*tls.Config, error) { +func createClientTLSConfig(certBytes []byte, key *ecdsa.PrivateKey) (*tls.Config, error) { clientCert, err := x509.ParseCertificate(certBytes) if err != nil { return nil, fmt.Errorf("error parsing generated plugin certificate: %v", err) @@ -95,9 +98,9 @@ func CreateClientTLSConfig(certBytes []byte, key *ecdsa.PrivateKey) (*tls.Config return tlsConfig, nil } -// WrapServerConfig is used to create a server certificate and private key, then +// wrapServerConfig is used to create a server certificate and private key, then // wrap them in an unwrap token for later retrieval by the plugin. -func WrapServerConfig(sys RunnerUtil, certBytes []byte, key *ecdsa.PrivateKey) (string, error) { +func wrapServerConfig(sys RunnerUtil, certBytes []byte, key *ecdsa.PrivateKey) (string, error) { rawKey, err := x509.MarshalECPrivateKey(key) if err != nil { return "", err @@ -120,11 +123,6 @@ func VaultPluginTLSProvider(apiTLSConfig *api.TLSConfig) func() (*tls.Config, er return func() (*tls.Config, error) { unwrapToken := os.Getenv(PluginUnwrapTokenEnv) - // Ensure unwrap token is a JWT - if strings.Count(unwrapToken, ".") != 2 { - return nil, errors.New("Could not parse unwraptoken") - } - // Parse the JWT and retrieve the vault address wt, err := jws.ParseJWT([]byte(unwrapToken)) if err != nil { From 9e28b03c9b5c2a420ab5d88200d6a9a2cabf9ee8 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 3 May 2017 18:41:39 -0700 Subject: [PATCH 146/152] add new mysql plugin names and fix grammar --- .../docs/secrets/databases/cassandra.html.md | 2 +- .../source/docs/secrets/databases/mssql.html.md | 2 +- .../docs/secrets/databases/mysql-maria.html.md | 15 +++++++++++++-- .../docs/secrets/databases/postgresql.html.md | 2 +- 4 files changed, 16 insertions(+), 5 deletions(-) diff --git a/website/source/docs/secrets/databases/cassandra.html.md b/website/source/docs/secrets/databases/cassandra.html.md index 1d8468ad3a71..0e29d0300b7c 100644 --- a/website/source/docs/secrets/databases/cassandra.html.md +++ b/website/source/docs/secrets/databases/cassandra.html.md @@ -57,6 +57,6 @@ This role can be used to retrieve a new set of credentials by querying the The full list of configurable options can be seen in the [Cassandra database plugin API](/api/secret/databases/cassandra.html) page. -Or for more information on the Database secret backend's HTTP API please see the [Database secret +For more information on the Database secret backend's HTTP API please see the [Database secret backend API](/api/secret/databases/index.html) page. diff --git a/website/source/docs/secrets/databases/mssql.html.md b/website/source/docs/secrets/databases/mssql.html.md index fec8924b3fa8..c2f7ff5fe3cf 100644 --- a/website/source/docs/secrets/databases/mssql.html.md +++ b/website/source/docs/secrets/databases/mssql.html.md @@ -55,6 +55,6 @@ This role can now be used to retrieve a new set of credentials by querying the The full list of configurable options can be seen in the [MSSQL database plugin API](/api/secret/databases/mssql.html) page. -Or for more information on the Database secret backend's HTTP API please see the [Database secret +For more information on the Database secret backend's HTTP API please see the [Database secret backend API](/api/secret/databases/index.html) page. diff --git a/website/source/docs/secrets/databases/mysql-maria.html.md b/website/source/docs/secrets/databases/mysql-maria.html.md index c5eea4b7bac9..ae6c19eac284 100644 --- a/website/source/docs/secrets/databases/mysql-maria.html.md +++ b/website/source/docs/secrets/databases/mysql-maria.html.md @@ -8,7 +8,8 @@ description: |- # MySQL/MariaDB Database Plugin -Name: `mysql-database-plugin` +Name: `mysql-database-plugin`, `aurora-database-plugin`, `rds-database-plugin`, +`mysql-legacy-database-plugin` The MySQL Database Plugin is one of the supported plugins for the Database backend. This plugin generates database credentials dynamically based on @@ -17,6 +18,16 @@ configured roles for the MySQL database. See the [Database Backend](/docs/secrets/databases/index.html) docs for more information about setting up the Database Backend. +This plugin has a few different instances built into vault, each instance is for +a slightly different MySQL driver. The only difference between these plugins is +the length of usernames generated by the plugin as different versions of mysql +accept different lengths. The availible plugins are: + + - mysql-database-plugin + - aurora-database-plugin + - rds-database-plugin + - mysql-legacy-database-plugin + ## Quick Start After the Database Backend is mounted you can configure a MySQL connection @@ -53,6 +64,6 @@ This role can now be used to retrieve a new set of credentials by querying the The full list of configurable options can be seen in the [MySQL database plugin API](/api/secret/databases/mysql-maria.html) page. -Or for more information on the Database secret backend's HTTP API please see the [Database secret +For more information on the Database secret backend's HTTP API please see the [Database secret backend API](/api/secret/databases/index.html) page. diff --git a/website/source/docs/secrets/databases/postgresql.html.md b/website/source/docs/secrets/databases/postgresql.html.md index 72601e34fc4c..e04cc087c19d 100644 --- a/website/source/docs/secrets/databases/postgresql.html.md +++ b/website/source/docs/secrets/databases/postgresql.html.md @@ -55,6 +55,6 @@ This role can be used to retrieve a new set of credentials by querying the The full list of configurable options can be seen in the [PostgreSQL database plugin API](/api/secret/databases/postgresql.html) page. -Or for more information on the Database secret backend's HTTP API please see the [Database secret +For more information on the Database secret backend's HTTP API please see the [Database secret backend API](/api/secret/databases/index.html) page. From c825362304b7d36f6af85b88bde39e7ffb384fcc Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 4 May 2017 10:41:59 -0700 Subject: [PATCH 147/152] PR comments --- builtin/logical/database/path_roles.go | 2 +- helper/builtinplugins/builtin.go | 4 ++-- logical/system_view.go | 2 +- website/source/docs/secrets/databases/mysql-maria.html.md | 6 +++--- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/builtin/logical/database/path_roles.go b/builtin/logical/database/path_roles.go index c81261804641..8be33c0a1215 100644 --- a/builtin/logical/database/path_roles.go +++ b/builtin/logical/database/path_roles.go @@ -57,7 +57,7 @@ func pathRoles(b *databaseBackend) *framework.Path { }, "rollback_statements": { Type: framework.TypeString, - Description: `SQL statements to be executed to revoke a user. Must be a semicolon-separated + Description: `Statements to be executed to revoke a user. Must be a semicolon-separated string, a base64-encoded semicolon-separated string, a serialized JSON string array, or a base64-encoded serialized JSON string array. The '{{name}}' value will be substituted.`, diff --git a/helper/builtinplugins/builtin.go b/helper/builtinplugins/builtin.go index 8e6ed22ef674..a2100e931872 100644 --- a/helper/builtinplugins/builtin.go +++ b/helper/builtinplugins/builtin.go @@ -13,8 +13,8 @@ var plugins map[string]BuiltinFactory = map[string]BuiltinFactory{ // These four plugins all use the same mysql implementation but with // different username settings passed by the constructor. "mysql-database-plugin": mysql.New(mysql.DisplayNameLen, mysql.UsernameLen), - "aurora-database-plugin": mysql.New(mysql.LegacyDisplayNameLen, mysql.LegacyUsernameLen), - "rds-database-plugin": mysql.New(mysql.LegacyDisplayNameLen, mysql.LegacyUsernameLen), + "mysql-aurora-database-plugin": mysql.New(mysql.LegacyDisplayNameLen, mysql.LegacyUsernameLen), + "mysql-rds-database-plugin": mysql.New(mysql.LegacyDisplayNameLen, mysql.LegacyUsernameLen), "mysql-legacy-database-plugin": mysql.New(mysql.LegacyDisplayNameLen, mysql.LegacyUsernameLen), "postgresql-database-plugin": postgresql.New, diff --git a/logical/system_view.go b/logical/system_view.go index 175edc0f9a40..64fc51c7b477 100644 --- a/logical/system_view.go +++ b/logical/system_view.go @@ -49,7 +49,7 @@ type SystemView interface { // name. Returns a PluginRunner or an error if a plugin can not be found. LookupPlugin(string) (*pluginutil.PluginRunner, error) - // MlockEnabled returns the configuration setting for Enableing mlock on + // MlockEnabled returns the configuration setting for enabling mlock on // plugins. MlockEnabled() bool } diff --git a/website/source/docs/secrets/databases/mysql-maria.html.md b/website/source/docs/secrets/databases/mysql-maria.html.md index ae6c19eac284..f4cf3640ba92 100644 --- a/website/source/docs/secrets/databases/mysql-maria.html.md +++ b/website/source/docs/secrets/databases/mysql-maria.html.md @@ -8,7 +8,7 @@ description: |- # MySQL/MariaDB Database Plugin -Name: `mysql-database-plugin`, `aurora-database-plugin`, `rds-database-plugin`, +Name: `mysql-database-plugin`, `mysql-aurora-database-plugin`, `mysql-rds-database-plugin`, `mysql-legacy-database-plugin` The MySQL Database Plugin is one of the supported plugins for the Database @@ -24,8 +24,8 @@ the length of usernames generated by the plugin as different versions of mysql accept different lengths. The availible plugins are: - mysql-database-plugin - - aurora-database-plugin - - rds-database-plugin + - mysql-aurora-database-plugin + - mysql-rds-database-plugin - mysql-legacy-database-plugin ## Quick Start From 886f873ffcf0445c403205a57b48f2b2ff5d2d2a Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 4 May 2017 11:45:27 -0700 Subject: [PATCH 148/152] Update docs and return a better error message --- builtin/logical/database/backend.go | 2 +- website/source/api/secret/databases/index.html.md | 10 ++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index 3d1502805128..91b92e438ae5 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -106,7 +106,7 @@ func (b *databaseBackend) createDBObj(s logical.Storage, name string) (dbplugin. func (b *databaseBackend) DatabaseConfig(s logical.Storage, name string) (*DatabaseConfig, error) { entry, err := s.Get(fmt.Sprintf("config/%s", name)) if err != nil { - return nil, fmt.Errorf("failed to read connection configuration with name: %s", name) + return nil, fmt.Errorf("failed to read connection configuration: %s", err) } if entry == nil { return nil, fmt.Errorf("failed to find entry for connection with name: %s", name) diff --git a/website/source/api/secret/databases/index.html.md b/website/source/api/secret/databases/index.html.md index f55998aceb8a..d43e49789aad 100644 --- a/website/source/api/secret/databases/index.html.md +++ b/website/source/api/secret/databases/index.html.md @@ -162,11 +162,13 @@ This endpoint creates or updates a role definition. - `db_name` `(string: )` - The name of the database connection to use for this role. -- `default_ttl` `(string: )` - Specifies the TTL for the lease - associated with this role. +- `default_ttl` `(string/int: 0)` - Specifies the TTL for the leases + associated with this role. Accepts time suffixed strings ("1h") or an integer + number of seconds. Defaults to system/backend default TTL time. -- `max_ttl` `(string: )` - Specifies the maximum TTL for the lease - associated with this role. +- `max_ttl` `(string/int: 0)` - Specifies the maximum TTL for the leases + associated with this role. Accepts time suffixed strings ("1h") or an integer + number of seconds. Defaults to system/backend default TTL time. - `creation_statements` `(string: )` – Specifies the database statements executed to create and configure a user. Must be a From 17bea6540bfb800a91463bbaaf992bf402521645 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 4 May 2017 12:36:06 -0700 Subject: [PATCH 149/152] Don't store the plugin directory prepended command in the barrier, prepend on get --- vault/plugin_catalog.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go index 095d81b1e4f5..79474e601fb3 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugin_catalog.go @@ -59,6 +59,9 @@ func (c *PluginCatalog) Get(name string) (*pluginutil.PluginRunner, error) { return nil, fmt.Errorf("failed to decode plugin entry: %v", err) } + // prepend the plugin directory to the command + entry.Command = filepath.Join(c.directory, entry.Command) + return entry, nil } } @@ -85,14 +88,11 @@ func (c *PluginCatalog) Set(name, command string, sha256 []byte) error { defer c.lock.Unlock() parts := strings.Split(command, " ") - command = parts[0] - args := parts[1:] - - command = filepath.Join(c.directory, command) // Best effort check to make sure the command isn't breaking out of the // configured plugin directory. - sym, err := filepath.EvalSymlinks(command) + commandFull := filepath.Join(c.directory, parts[0]) + sym, err := filepath.EvalSymlinks(commandFull) if err != nil { return fmt.Errorf("error while validating the command path: %v", err) } @@ -107,8 +107,8 @@ func (c *PluginCatalog) Set(name, command string, sha256 []byte) error { entry := &pluginutil.PluginRunner{ Name: name, - Command: command, - Args: args, + Command: parts[0], + Args: parts[1:], Sha256: sha256, Builtin: false, } From 2e82e00f49c3f31cd54138aa2a3debf105294288 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 4 May 2017 13:38:49 -0700 Subject: [PATCH 150/152] update docs --- website/source/docs/secrets/databases/mysql-maria.html.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/website/source/docs/secrets/databases/mysql-maria.html.md b/website/source/docs/secrets/databases/mysql-maria.html.md index f4cf3640ba92..0d7c497480f8 100644 --- a/website/source/docs/secrets/databases/mysql-maria.html.md +++ b/website/source/docs/secrets/databases/mysql-maria.html.md @@ -15,9 +15,6 @@ The MySQL Database Plugin is one of the supported plugins for the Database backend. This plugin generates database credentials dynamically based on configured roles for the MySQL database. -See the [Database Backend](/docs/secrets/databases/index.html) docs for more -information about setting up the Database Backend. - This plugin has a few different instances built into vault, each instance is for a slightly different MySQL driver. The only difference between these plugins is the length of usernames generated by the plugin as different versions of mysql @@ -28,6 +25,9 @@ accept different lengths. The availible plugins are: - mysql-rds-database-plugin - mysql-legacy-database-plugin +See the [Database Backend](/docs/secrets/databases/index.html) docs for more +information about setting up the Database Backend. + ## Quick Start After the Database Backend is mounted you can configure a MySQL connection From 65b7bba360dfb3df8e438b65efd3ae54d7ca1d61 Mon Sep 17 00:00:00 2001 From: Calvin Leung Huang Date: Thu, 4 May 2017 16:46:34 -0400 Subject: [PATCH 151/152] Update mssql docs --- website/source/docs/secrets/databases/mssql.html.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/website/source/docs/secrets/databases/mssql.html.md b/website/source/docs/secrets/databases/mssql.html.md index c2f7ff5fe3cf..63ea31c448e6 100644 --- a/website/source/docs/secrets/databases/mssql.html.md +++ b/website/source/docs/secrets/databases/mssql.html.md @@ -39,11 +39,11 @@ Once the MSSQL connection is configured we can add a role: $ vault write database/roles/readonly \ db_name=mssql \ creation_statements="CREATE LOGIN [{{name}}] WITH PASSWORD = '{{password}}';\ - USE AdventureWorks; CREATE USER [{{name}}] FOR LOGIN [{{name}}]; \ - GRANT SELECT ON SCHEMA::dbo TO [{{name}}];" \ + CREATE USER [{{name}}] FOR LOGIN [{{name}}];\ + GRANT SELECT, INSERT, UPDATE, DELETE ON SCHEMA::dbo TO [{{name}}];" \ default_ttl="1h" \ max_ttl="24h" - + Success! Data written to: database/roles/readonly ``` From c48b7fa8db6c1458d5a2b4d07a3f0cf03a8ec754 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 4 May 2017 14:07:12 -0700 Subject: [PATCH 152/152] Few docs updates --- website/source/docs/secrets/databases/mssql.html.md | 2 +- website/source/docs/secrets/databases/postgresql.html.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/website/source/docs/secrets/databases/mssql.html.md b/website/source/docs/secrets/databases/mssql.html.md index 63ea31c448e6..889e35a4fa02 100644 --- a/website/source/docs/secrets/databases/mssql.html.md +++ b/website/source/docs/secrets/databases/mssql.html.md @@ -40,7 +40,7 @@ $ vault write database/roles/readonly \ db_name=mssql \ creation_statements="CREATE LOGIN [{{name}}] WITH PASSWORD = '{{password}}';\ CREATE USER [{{name}}] FOR LOGIN [{{name}}];\ - GRANT SELECT, INSERT, UPDATE, DELETE ON SCHEMA::dbo TO [{{name}}];" \ + GRANT SELECT ON SCHEMA::dbo TO [{{name}}];" \ default_ttl="1h" \ max_ttl="24h" diff --git a/website/source/docs/secrets/databases/postgresql.html.md b/website/source/docs/secrets/databases/postgresql.html.md index e04cc087c19d..b2c0c7bb295b 100644 --- a/website/source/docs/secrets/databases/postgresql.html.md +++ b/website/source/docs/secrets/databases/postgresql.html.md @@ -27,7 +27,7 @@ configuration: $ vault write database/config/postgresql \ plugin_name=postgresql-database-plugin \ allowed_roles="readonly" \ - connection_url="postgresql://root:root@localhost:5432/postgres" + connection_url="postgresql://root:root@localhost:5432/" The following warnings were returned from the Vault server: * Read access to this endpoint should be controlled via ACLs as it will return the connection details as is, including passwords, if any.