Skip to content

Commit

Permalink
Merge pull request #7265 from planetscale/jg_vault_9.0
Browse files Browse the repository at this point in the history
Cherry pick version of #7233 for release-9.0
  • Loading branch information
deepthi authored Jan 7, 2021
2 parents 0472d47 + 095ffcc commit 283497f
Show file tree
Hide file tree
Showing 20 changed files with 1,290 additions and 5 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ require (
github.com/DataDog/datadog-go v2.2.0+incompatible
github.com/GeertJohan/go.rice v1.0.0
github.com/PuerkitoBio/goquery v1.5.1
github.com/aquarapid/vaultlib v0.5.1
github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6
github.com/armon/go-metrics v0.0.0-20190430140413-ec5e00d3c878
github.com/aws/aws-sdk-go v1.28.8
Expand Down
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRF
github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883/go.mod h1:rCTlJbsFo29Kk6CurOXKm700vrz8f0KW0JNfpkRJY/8=
github.com/andybalholm/cascadia v1.1.0 h1:BuuO6sSfQNFRu1LppgbD25Hr2vLYW25JvxHs5zzsLTo=
github.com/andybalholm/cascadia v1.1.0/go.mod h1:GsXiBklL0woXo1j/WYWtSYYC4ouU9PqHO0sqidkEA4Y=
github.com/aquarapid/vaultlib v0.5.1 h1:vuLWR6bZzLHybjJBSUYPgZlIp6KZ+SXeHLRRYTuk6d4=
github.com/aquarapid/vaultlib v0.5.1/go.mod h1:yT7AlEXtuabkxylOc/+Ulyp18tff1+QjgNLTnFWTlOs=
github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e h1:QEF07wC0T1rKkctt1RINW/+RMTVmiwxETico2l3gxJA=
github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o=
github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6 h1:G1bPvciwNyF7IUmKXNt9Ak3m6u9DE1rF+RmtIkBpVdA=
Expand Down Expand Up @@ -472,6 +474,8 @@ github.com/mattn/go-sqlite3 v1.14.0 h1:mLyGNKR8+Vv9CAU7PphKa2hkEqxxhn8i32J6FPj1/
github.com/mattn/go-sqlite3 v1.14.0/go.mod h1:JIl7NbARA7phWnGvh0LKTyg7S9BA+6gx71ShQilpsus=
github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU=
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/mch1307/vaultlib v0.5.0 h1:+tI8YCG033aVI+kAKwo0fwrUylFs+wO6DB7DM5qXJzU=
github.com/mch1307/vaultlib v0.5.0/go.mod h1:phFbO1oIDL1xTqUrNXbrAG0VdcYEKP8TNa9FJd7hFic=
github.com/miekg/dns v1.0.14 h1:9jZdLNd/P4+SfEJ0TNyxYpsK8N4GtfylBLqtbYN1sbA=
github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=
github.com/miekg/dns v1.1.26 h1:gPxPSwALAeHJSjarOs00QjVdV9QoBvc1D2ujQUr5BzU=
Expand Down
28 changes: 28 additions & 0 deletions go/cmd/vtgate/plugin_auth_vault.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
Copyright 2020 The Vitess Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreedto in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package main

// This plugin imports InitAuthServerVault to register the HashiCorp Vault implementation of AuthServer.

import (
"vitess.io/vitess/go/mysql"
"vitess.io/vitess/go/vt/vtgate"
)

func init() {
vtgate.RegisterPluginInitializer(func() { mysql.InitAuthServerVault() })
}
4 changes: 2 additions & 2 deletions go/cmd/vttablet/vttablet.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func main() {
log.Exitf("failed to parse -tablet-path: %v", err)
}

// config and mycnf intializations are intertwined.
// config and mycnf initializations are intertwined.
config, mycnf := initConfig(tabletAlias)

ts := topo.Open()
Expand Down Expand Up @@ -106,7 +106,7 @@ func main() {
VREngine: vreplication.NewEngine(config, ts, tabletAlias.Cell, mysqld),
}
if err := tm.Start(tablet, config.Healthcheck.IntervalSeconds.Get()); err != nil {
log.Exitf("failed to parse -tablet-path: %v", err)
log.Exitf("failed to parse -tablet-path or initialize DB credentials: %v", err)
}
servenv.OnClose(func() {
// Close the tm so that our topo entry gets pruned properly and any
Expand Down
302 changes: 302 additions & 0 deletions go/mysql/auth_server_vault.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,302 @@
/*
Copyright 2020 The Vitess Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package mysql

import (
"bytes"
"flag"
"fmt"
"io/ioutil"
"net"
"os"
"os/signal"
"strings"
"sync"
"syscall"
"time"

vaultapi "github.com/aquarapid/vaultlib"

"vitess.io/vitess/go/vt/log"
)

var (
vaultAddr = flag.String("mysql_auth_vault_addr", "", "URL to Vault server")
vaultTimeout = flag.Duration("mysql_auth_vault_timeout", 10*time.Second, "Timeout for vault API operations")
vaultCACert = flag.String("mysql_auth_vault_tls_ca", "", "Path to CA PEM for validating Vault server certificate")
vaultPath = flag.String("mysql_auth_vault_path", "", "Vault path to vtgate credentials JSON blob, e.g.: secret/data/prod/vtgatecreds")
vaultCacheTTL = flag.Duration("mysql_auth_vault_ttl", 30*time.Minute, "How long to cache vtgate credentials from the Vault server")
vaultTokenFile = flag.String("mysql_auth_vault_tokenfile", "", "Path to file containing Vault auth token; token can also be passed using VAULT_TOKEN environment variable")
vaultRoleID = flag.String("mysql_auth_vault_roleid", "", "Vault AppRole id; can also be passed using VAULT_ROLEID environment variable")
vaultRoleSecretIDFile = flag.String("mysql_auth_vault_role_secretidfile", "", "Path to file containing Vault AppRole secret_id; can also be passed using VAULT_SECRETID environment variable")
vaultRoleMountPoint = flag.String("mysql_auth_vault_role_mountpoint", "approle", "Vault AppRole mountpoint; can also be passed using VAULT_MOUNTPOINT environment variable")
)

// AuthServerVault implements AuthServer with a config loaded from Vault.
type AuthServerVault struct {
mu sync.Mutex
// method can be set to:
// - MysqlNativePassword
// - MysqlClearPassword
// - MysqlDialog
// It defaults to MysqlNativePassword.
method string
// users, passwords and user data
// We use the same JSON format as for -mysql_auth_server_static
// Acts as a cache for the in-Vault data
entries map[string][]*AuthServerStaticEntry
vaultCacheExpireTicker *time.Ticker
vaultClient *vaultapi.Client
vaultPath string
vaultTTL time.Duration

sigChan chan os.Signal
}

// InitAuthServerVault - entrypoint for initialization of Vault AuthServer implementation
func InitAuthServerVault() {
// Check critical parameters.
if *vaultAddr == "" {
log.Infof("Not configuring AuthServerVault, as -mysql_auth_vault_addr is empty.")
return
}
if *vaultPath == "" {
log.Exitf("If using Vault auth server, -mysql_auth_vault_path is required.")
}

registerAuthServerVault(*vaultAddr, *vaultTimeout, *vaultCACert, *vaultPath, *vaultCacheTTL, *vaultTokenFile, *vaultRoleID, *vaultRoleSecretIDFile, *vaultRoleMountPoint)
}

func registerAuthServerVault(addr string, timeout time.Duration, caCertPath string, path string, ttl time.Duration, tokenFilePath string, roleID string, secretIDPath string, roleMountPoint string) {
authServerVault, err := newAuthServerVault(addr, timeout, caCertPath, path, ttl, tokenFilePath, roleID, secretIDPath, roleMountPoint)
if err != nil {
log.Exitf("%s", err)
}
RegisterAuthServerImpl("vault", authServerVault)
}

func newAuthServerVault(addr string, timeout time.Duration, caCertPath string, path string, ttl time.Duration, tokenFilePath string, roleID string, secretIDPath string, roleMountPoint string) (*AuthServerVault, error) {
// Validate more parameters
token, err := readFromFile(tokenFilePath)
if err != nil {
return nil, fmt.Errorf("No Vault token in provided filename for -mysql_auth_vault_tokenfile")
}
secretID, err := readFromFile(secretIDPath)
if err != nil {
return nil, fmt.Errorf("No Vault secret_id in provided filename for -mysql_auth_vault_role_secretidfile")
}

config := vaultapi.NewConfig()

// All these can be overriden by environment
// so we need to check if they have been set by NewConfig
if config.Address == "" {
config.Address = addr
}
if config.Timeout == (0 * time.Second) {
config.Timeout = timeout
}
if config.CACert == "" {
config.CACert = caCertPath
}
if config.Token == "" {
config.Token = token
}
if config.AppRoleCredentials.RoleID == "" {
config.AppRoleCredentials.RoleID = roleID
}
if config.AppRoleCredentials.SecretID == "" {
config.AppRoleCredentials.SecretID = secretID
}
if config.AppRoleCredentials.MountPoint == "" {
config.AppRoleCredentials.MountPoint = roleMountPoint
}

if config.CACert != "" {
// If we provide a CA, ensure we actually use it
config.InsecureSSL = false
}

client, err := vaultapi.NewClient(config)
if err != nil || client == nil {
log.Errorf("Error in vault client initialization, will retry: %v", err)
}

a := &AuthServerVault{
vaultClient: client,
vaultPath: path,
vaultTTL: ttl,
method: MysqlNativePassword,
entries: make(map[string][]*AuthServerStaticEntry),
}

a.reloadVault()
a.installSignalHandlers()
return a, nil
}

func (a *AuthServerVault) setTTLTicker(ttl time.Duration) {
a.mu.Lock()
defer a.mu.Unlock()
if a.vaultCacheExpireTicker == nil {
a.vaultCacheExpireTicker = time.NewTicker(ttl)
go func() {
for range a.vaultCacheExpireTicker.C {
a.sigChan <- syscall.SIGHUP
}
}()
} else {
a.vaultCacheExpireTicker.Reset(ttl)
}
}

// Reload JSON auth key from Vault. Return true if successful, false if not
func (a *AuthServerVault) reloadVault() error {
a.mu.Lock()
secret, err := a.vaultClient.GetSecret(a.vaultPath)
a.mu.Unlock()
a.setTTLTicker(10 * time.Second) // Reload frequently on error

if err != nil {
return fmt.Errorf("Error in vtgate Vault auth server params: %v", err)
}

if secret.JSONSecret == nil {
return fmt.Errorf("Empty vtgate credentials retrieved from Vault server")
}

entries := make(map[string][]*AuthServerStaticEntry)
if err := parseConfig(secret.JSONSecret, &entries); err != nil {
return fmt.Errorf("Error parsing vtgate Vault auth server config: %v", err)
}
if len(entries) == 0 {
return fmt.Errorf("vtgate credentials from Vault empty! Not updating previously cached values")
}

log.Infof("reloadVault(): success. Client status: %s", a.vaultClient.GetStatus())
a.mu.Lock()
a.entries = entries
a.mu.Unlock()
a.setTTLTicker(a.vaultTTL)
return nil
}

func (a *AuthServerVault) installSignalHandlers() {
a.mu.Lock()
defer a.mu.Unlock()

a.sigChan = make(chan os.Signal, 1)
signal.Notify(a.sigChan, syscall.SIGHUP)
go func() {
for range a.sigChan {
err := a.reloadVault()
if err != nil {
log.Errorf("%s", err)
}

}
}()
}

func (a *AuthServerVault) close() {
log.Warningf("Closing AuthServerVault instance.")
a.mu.Lock()
defer a.mu.Unlock()
if a.vaultCacheExpireTicker != nil {
a.vaultCacheExpireTicker.Stop()
}
if a.sigChan != nil {
signal.Stop(a.sigChan)
}
}

// AuthMethod is part of the AuthServer interface.
func (a *AuthServerVault) AuthMethod(user string) (string, error) {
return a.method, nil
}

// Salt is part of the AuthServer interface.
func (a *AuthServerVault) Salt() ([]byte, error) {
return NewSalt()
}

// ValidateHash is part of the AuthServer interface.
func (a *AuthServerVault) ValidateHash(salt []byte, user string, authResponse []byte, remoteAddr net.Addr) (Getter, error) {
a.mu.Lock()
userEntries, ok := a.entries[user]
a.mu.Unlock()

if !ok {
return &StaticUserData{}, NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "Access denied for user '%v'", user)
}

for _, entry := range userEntries {
if entry.MysqlNativePassword != "" {
isPass := isPassScrambleMysqlNativePassword(authResponse, salt, entry.MysqlNativePassword)
if matchSourceHost(remoteAddr, entry.SourceHost) && isPass {
return &StaticUserData{entry.UserData, entry.Groups}, nil
}
} else {
computedAuthResponse := ScramblePassword(salt, []byte(entry.Password))
// Validate the password.
if matchSourceHost(remoteAddr, entry.SourceHost) && bytes.Equal(authResponse, computedAuthResponse) {
return &StaticUserData{entry.UserData, entry.Groups}, nil
}
}
}
return &StaticUserData{}, NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "Access denied for user '%v'", user)
}

// Negotiate is part of the AuthServer interface.
// It will be called if method is anything else than MysqlNativePassword.
// We only recognize MysqlClearPassword and MysqlDialog here.
func (a *AuthServerVault) Negotiate(c *Conn, user string, remoteAddr net.Addr) (Getter, error) {
// Finish the negotiation.
password, err := AuthServerNegotiateClearOrDialog(c, a.method)
if err != nil {
return nil, err
}

a.mu.Lock()
userEntries, ok := a.entries[user]
a.mu.Unlock()

if !ok {
return &StaticUserData{}, NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "Access denied for user '%v'", user)
}
for _, entry := range userEntries {
// Validate the password.
if matchSourceHost(remoteAddr, entry.SourceHost) && entry.Password == password {
return &StaticUserData{entry.UserData, entry.Groups}, nil
}
}
return &StaticUserData{}, NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "Access denied for user '%v'", user)
}

// We ignore most errors here, to allow us to retry cleanly
// or ignore the cases where the input is not passed by file, but via env
func readFromFile(filePath string) (string, error) {
if filePath == "" {
return "", nil
}
fileBytes, err := ioutil.ReadFile(filePath)
if err != nil {
log.Errorf("Could not read file: %s", filePath)
return "", err
}
return strings.TrimSpace(string(fileBytes)), nil
}
Loading

0 comments on commit 283497f

Please sign in to comment.