Skip to content

Commit

Permalink
Mongodb plugin (#2698)
Browse files Browse the repository at this point in the history
* WIP on mongodb plugin

* Add mongodb plugin

* Add tests

* Update mongodb.CreateUser() comment

* Update docs

* Add missing docs

* Fix mongodb docs

* Minor comment and test updates

* Fix imports

* Fix dockertest import

* Set c.Initialized at the end, check for empty CreationStmts first on CreateUser

* Remove Initialized check on Connection()

* Add back Initialized check

* Update docs

* Move connProducer and credsProducer into pkg for  mongodb and cassandra

* Chage parseMongoURL to be a private func

* Default to admin if no db is provided in creation_statements

* Update comments and docs
  • Loading branch information
calvn authored May 11, 2017
1 parent b203d51 commit a4c652c
Show file tree
Hide file tree
Showing 20 changed files with 809 additions and 34 deletions.
4 changes: 2 additions & 2 deletions plugins/database/cassandra/cassandra.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ type Cassandra struct {

// New returns a new Cassandra instance
func New() (interface{}, error) {
connProducer := &connutil.CassandraConnectionProducer{}
connProducer := &cassandraConnectionProducer{}
connProducer.Type = cassandraTypeName

credsProducer := &credsutil.CassandraCredentialsProducer{}
credsProducer := &cassandraCredentialsProducer{}

dbType := &Cassandra{
ConnectionProducer: connProducer,
Expand Down
3 changes: 1 addition & 2 deletions plugins/database/cassandra/cassandra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (

"github.com/gocql/gocql"
"github.com/hashicorp/vault/builtin/logical/database/dbplugin"
"github.com/hashicorp/vault/plugins/helper/database/connutil"
dockertest "gopkg.in/ory-am/dockertest.v3"
)

Expand Down Expand Up @@ -85,7 +84,7 @@ func TestCassandra_Initialize(t *testing.T) {

dbRaw, _ := New()
db := dbRaw.(*Cassandra)
connProducer := db.ConnectionProducer.(*connutil.CassandraConnectionProducer)
connProducer := db.ConnectionProducer.(*cassandraConnectionProducer)

err := db.Initialize(connectionDetails, true)
if err != nil {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package connutil
package cassandra

import (
"crypto/tls"
Expand All @@ -13,11 +13,12 @@ import (
"github.com/hashicorp/vault/helper/certutil"
"github.com/hashicorp/vault/helper/parseutil"
"github.com/hashicorp/vault/helper/tlsutil"
"github.com/hashicorp/vault/plugins/helper/database/connutil"
)

// CassandraConnectionProducer implements ConnectionProducer and provides an
// cassandraConnectionProducer implements ConnectionProducer and provides an
// interface for cassandra databases to make connections.
type CassandraConnectionProducer struct {
type cassandraConnectionProducer struct {
Hosts string `json:"hosts" structs:"hosts" mapstructure:"hosts"`
Username string `json:"username" structs:"username" mapstructure:"username"`
Password string `json:"password" structs:"password" mapstructure:"password"`
Expand All @@ -41,15 +42,14 @@ type CassandraConnectionProducer struct {
sync.Mutex
}

func (c *CassandraConnectionProducer) Initialize(conf map[string]interface{}, verifyConnection bool) error {
func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}, verifyConnection bool) error {
c.Lock()
defer c.Unlock()

err := mapstructure.Decode(conf, c)
if err != nil {
return err
}
c.Initialized = true

if c.ConnectTimeoutRaw == nil {
c.ConnectTimeoutRaw = "0s"
Expand Down Expand Up @@ -100,17 +100,22 @@ func (c *CassandraConnectionProducer) Initialize(conf map[string]interface{}, ve
c.TLS = true
}

// Set initialized to true at this point since all fields are set,
// and the connection can be established at a later time.
c.Initialized = true

if verifyConnection {
if _, err := c.Connection(); err != nil {
return fmt.Errorf("error Initalizing Connection: %s", err)
return fmt.Errorf("error verifying connection: %s", err)
}
}

return nil
}

func (c *CassandraConnectionProducer) Connection() (interface{}, error) {
func (c *cassandraConnectionProducer) Connection() (interface{}, error) {
if !c.Initialized {
return nil, errNotInitialized
return nil, connutil.ErrNotInitialized
}

// If we already have a DB, return it
Expand All @@ -129,7 +134,7 @@ func (c *CassandraConnectionProducer) Connection() (interface{}, error) {
return session, nil
}

func (c *CassandraConnectionProducer) Close() error {
func (c *cassandraConnectionProducer) Close() error {
// Grab the write lock
c.Lock()
defer c.Unlock()
Expand All @@ -143,7 +148,7 @@ func (c *CassandraConnectionProducer) Close() error {
return nil
}

func (c *CassandraConnectionProducer) createSession() (*gocql.Session, error) {
func (c *cassandraConnectionProducer) createSession() (*gocql.Session, error) {
clusterConfig := gocql.NewCluster(strings.Split(c.Hosts, ",")...)
clusterConfig.Authenticator = gocql.PasswordAuthenticator{
Username: c.Username,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package credsutil
package cassandra

import (
"fmt"
Expand All @@ -8,11 +8,11 @@ import (
uuid "github.com/hashicorp/go-uuid"
)

// CassandraCredentialsProducer implements CredentialsProducer and provides an
// cassandraCredentialsProducer implements CredentialsProducer and provides an
// interface for cassandra databases to generate user information.
type CassandraCredentialsProducer struct{}
type cassandraCredentialsProducer struct{}

func (ccp *CassandraCredentialsProducer) GenerateUsername(displayName string) (string, error) {
func (ccp *cassandraCredentialsProducer) GenerateUsername(displayName string) (string, error) {
userUUID, err := uuid.GenerateUUID()
if err != nil {
return "", err
Expand All @@ -23,7 +23,7 @@ func (ccp *CassandraCredentialsProducer) GenerateUsername(displayName string) (s
return username, nil
}

func (ccp *CassandraCredentialsProducer) GeneratePassword() (string, error) {
func (ccp *cassandraCredentialsProducer) GeneratePassword() (string, error) {
password, err := uuid.GenerateUUID()
if err != nil {
return "", err
Expand All @@ -32,6 +32,6 @@ func (ccp *CassandraCredentialsProducer) GeneratePassword() (string, error) {
return password, nil
}

func (ccp *CassandraCredentialsProducer) GenerateExpiration(ttl time.Time) (string, error) {
func (ccp *cassandraCredentialsProducer) GenerateExpiration(ttl time.Time) (string, error) {
return "", nil
}
8 changes: 4 additions & 4 deletions plugins/database/cassandra/test-fixtures/cassandra.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ seed_provider:
parameters:
# seeds is actually a comma-delimited list of addresses.
# Ex: "<ip1>,<ip2>,<ip3>"
- seeds: "172.17.0.2"
- seeds: "172.17.0.4"

# For workloads with more data than can fit in memory, Cassandra's
# bottleneck will be reads that need to fetch data from
Expand Down Expand Up @@ -572,7 +572,7 @@ ssl_storage_port: 7001
#
# Setting listen_address to 0.0.0.0 is always wrong.
#
listen_address: 172.17.0.2
listen_address: 172.17.0.4

# Set listen_address OR listen_interface, not both. Interfaces must correspond
# to a single address, IP aliasing is not supported.
Expand All @@ -586,7 +586,7 @@ listen_address: 172.17.0.2

# Address to broadcast to other Cassandra nodes
# Leaving this blank will set it to the same value as listen_address
broadcast_address: 172.17.0.2
broadcast_address: 172.17.0.4

# When using multiple physical network interfaces, set this
# to true to listen on broadcast_address in addition to
Expand Down Expand Up @@ -668,7 +668,7 @@ rpc_port: 9160
# be set to 0.0.0.0. If left blank, this will be set to the value of
# rpc_address. If rpc_address is set to 0.0.0.0, broadcast_rpc_address must
# be set.
broadcast_rpc_address: 172.17.0.2
broadcast_rpc_address: 172.17.0.4

# enable or disable keepalive on rpc/native connections
rpc_keepalive: true
Expand Down
167 changes: 167 additions & 0 deletions plugins/database/mongodb/connection_producer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
package mongodb

import (
"crypto/tls"
"errors"
"fmt"
"net"
"net/url"
"strconv"
"strings"
"sync"
"time"

"github.com/hashicorp/vault/plugins/helper/database/connutil"
"github.com/mitchellh/mapstructure"

"gopkg.in/mgo.v2"
)

// mongoDBConnectionProducer implements ConnectionProducer and provides an
// interface for databases to make connections.
type mongoDBConnectionProducer struct {
ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"`

Initialized bool
Type string
session *mgo.Session
sync.Mutex
}

// Initialize parses connection configuration.
func (c *mongoDBConnectionProducer) Initialize(conf map[string]interface{}, verifyConnection bool) error {
c.Lock()
defer c.Unlock()

err := mapstructure.Decode(conf, c)
if err != nil {
return err
}

if len(c.ConnectionURL) == 0 {
return fmt.Errorf("connection_url cannot be empty")
}

// Set initialized to true at this point since all fields are set,
// and the connection can be established at a later time.
c.Initialized = true

if verifyConnection {
if _, err := c.Connection(); err != nil {
return fmt.Errorf("error verifying connection: %s", err)
}

if err := c.session.Ping(); err != nil {
return fmt.Errorf("error verifying connection: %s", err)
}
}

return nil
}

// Connection creates a database connection.
func (c *mongoDBConnectionProducer) Connection() (interface{}, error) {
if !c.Initialized {
return nil, connutil.ErrNotInitialized
}

if c.session != nil {
return c.session, nil
}

dialInfo, err := parseMongoURL(c.ConnectionURL)
if err != nil {
return nil, err
}

c.session, err = mgo.DialWithInfo(dialInfo)
if err != nil {
return nil, err
}
c.session.SetSyncTimeout(1 * time.Minute)
c.session.SetSocketTimeout(1 * time.Minute)

return nil, nil
}

// Close terminates the database connection.
func (c *mongoDBConnectionProducer) Close() error {
c.Lock()
defer c.Unlock()

if c.session != nil {
c.session.Close()
}

c.session = nil

return nil
}

func parseMongoURL(rawURL string) (*mgo.DialInfo, error) {
url, err := url.Parse(rawURL)
if err != nil {
return nil, err
}

info := mgo.DialInfo{
Addrs: strings.Split(url.Host, ","),
Database: strings.TrimPrefix(url.Path, "/"),
Timeout: 10 * time.Second,
}

if url.User != nil {
info.Username = url.User.Username()
info.Password, _ = url.User.Password()
}

query := url.Query()
for key, values := range query {
var value string
if len(values) > 0 {
value = values[0]
}

switch key {
case "authSource":
info.Source = value
case "authMechanism":
info.Mechanism = value
case "gssapiServiceName":
info.Service = value
case "replicaSet":
info.ReplicaSetName = value
case "maxPoolSize":
poolLimit, err := strconv.Atoi(value)
if err != nil {
return nil, errors.New("bad value for maxPoolSize: " + value)
}
info.PoolLimit = poolLimit
case "ssl":
// Unfortunately, mgo doesn't support the ssl parameter in its MongoDB URI parsing logic, so we have to handle that
// ourselves. See https://github.com/go-mgo/mgo/issues/84
ssl, err := strconv.ParseBool(value)
if err != nil {
return nil, errors.New("bad value for ssl: " + value)
}
if ssl {
info.DialServer = func(addr *mgo.ServerAddr) (net.Conn, error) {
return tls.Dial("tcp", addr.String(), &tls.Config{})
}
}
case "connect":
if value == "direct" {
info.Direct = true
break
}
if value == "replicaSet" {
break
}
fallthrough
default:
return nil, errors.New("unsupported connection URL option: " + key + "=" + value)
}
}

return &info, nil
}
36 changes: 36 additions & 0 deletions plugins/database/mongodb/credentials_producer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package mongodb

import (
"fmt"
"time"

uuid "github.com/hashicorp/go-uuid"
)

// mongoDBCredentialsProducer implements CredentialsProducer and provides an
// interface for databases to generate user information.
type mongoDBCredentialsProducer struct{}

func (cp *mongoDBCredentialsProducer) GenerateUsername(displayName string) (string, error) {
userUUID, err := uuid.GenerateUUID()
if err != nil {
return "", err
}

username := fmt.Sprintf("vault-%s-%s", displayName, userUUID)

return username, nil
}

func (cp *mongoDBCredentialsProducer) GeneratePassword() (string, error) {
password, err := uuid.GenerateUUID()
if err != nil {
return "", err
}

return password, nil
}

func (cp *mongoDBCredentialsProducer) GenerateExpiration(ttl time.Time) (string, error) {
return "", nil
}
Loading

0 comments on commit a4c652c

Please sign in to comment.