Skip to content

Commit

Permalink
Merge pull request #8931 from planetscale/jg_flags
Browse files Browse the repository at this point in the history
Avoid some pollution of CLI argument namespace across vtgate and vttablet
  • Loading branch information
deepthi authored Oct 19, 2021
2 parents 0f1ee35 + 3743339 commit 5fee4d0
Show file tree
Hide file tree
Showing 9 changed files with 61 additions and 58 deletions.
4 changes: 2 additions & 2 deletions go/cmd/vtgate/plugin_auth_vault.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ package main
// This plugin imports InitAuthServerVault to register the HashiCorp Vault implementation of AuthServer.

import (
"vitess.io/vitess/go/mysql"
"vitess.io/vitess/go/mysql/vault"
"vitess.io/vitess/go/vt/vtgate"
)

func init() {
vtgate.RegisterPluginInitializer(func() { mysql.InitAuthServerVault() })
vtgate.RegisterPluginInitializer(func() { vault.InitAuthServerVault() })
}
3 changes: 1 addition & 2 deletions go/cmd/vtgate/vtgate.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ import (
"vitess.io/vitess/go/vt/topo"
"vitess.io/vitess/go/vt/topo/topoproto"
"vitess.io/vitess/go/vt/vtgate"
"vitess.io/vitess/go/vt/vttablet/tabletserver"

topodatapb "vitess.io/vitess/go/vt/proto/topodata"
)
Expand Down Expand Up @@ -127,7 +126,7 @@ func main() {
log.Errorf("unknown tablet type: %v", ttStr)
continue
}
if tabletserver.IsServingType(tt) {
if topoproto.IsServingType(tt) {
tabletTypes = append(tabletTypes, tt)
}
}
Expand Down
4 changes: 2 additions & 2 deletions go/mysql/auth_server_clientcert.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func (asl *AuthServerClientCert) UserEntryWithPassword(userCerts []*x509.Certifi
}

return &StaticUserData{
username: commonName,
groups: userCerts[0].DNSNames,
Username: commonName,
Groups: userCerts[0].DNSNames,
}, nil
}
22 changes: 12 additions & 10 deletions go/mysql/auth_server_static.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ func (a *AuthServerStatic) UserEntryWithPassword(userCerts []*x509.Certificate,

for _, entry := range entries {
// Validate the password.
if matchSourceHost(remoteAddr, entry.SourceHost) && subtle.ConstantTimeCompare([]byte(password), []byte(entry.Password)) == 1 {
if MatchSourceHost(remoteAddr, entry.SourceHost) && subtle.ConstantTimeCompare([]byte(password), []byte(entry.Password)) == 1 {
return &StaticUserData{entry.UserData, entry.Groups}, nil
}
}
Expand All @@ -198,13 +198,13 @@ func (a *AuthServerStatic) UserEntryWithHash(userCerts []*x509.Certificate, salt
}

isPass := VerifyHashedMysqlNativePassword(authResponse, salt, hash)
if matchSourceHost(remoteAddr, entry.SourceHost) && isPass {
if MatchSourceHost(remoteAddr, entry.SourceHost) && isPass {
return &StaticUserData{entry.UserData, entry.Groups}, nil
}
} else {
computedAuthResponse := ScrambleMysqlNativePassword(salt, []byte(entry.Password))
// Validate the password.
if matchSourceHost(remoteAddr, entry.SourceHost) && subtle.ConstantTimeCompare(authResponse, computedAuthResponse) == 1 {
if MatchSourceHost(remoteAddr, entry.SourceHost) && subtle.ConstantTimeCompare(authResponse, computedAuthResponse) == 1 {
return &StaticUserData{entry.UserData, entry.Groups}, nil
}
}
Expand All @@ -227,7 +227,7 @@ func (a *AuthServerStatic) UserEntryWithCacheHash(userCerts []*x509.Certificate,
computedAuthResponse := ScrambleCachingSha2Password(salt, []byte(entry.Password))

// Validate the password.
if matchSourceHost(remoteAddr, entry.SourceHost) && subtle.ConstantTimeCompare(authResponse, computedAuthResponse) == 1 {
if MatchSourceHost(remoteAddr, entry.SourceHost) && subtle.ConstantTimeCompare(authResponse, computedAuthResponse) == 1 {
return &StaticUserData{entry.UserData, entry.Groups}, AuthAccepted, nil
}
}
Expand Down Expand Up @@ -257,7 +257,7 @@ func (a *AuthServerStatic) reload() {
}

entries := make(map[string][]*AuthServerStaticEntry)
if err := parseConfig(jsonBytes, &entries); err != nil {
if err := ParseConfig(jsonBytes, &entries); err != nil {
log.Errorf("Error parsing auth server config: %v", err)
return
}
Expand Down Expand Up @@ -300,7 +300,8 @@ func (a *AuthServerStatic) close() {
}
}

func parseConfig(jsonBytes []byte, config *map[string][]*AuthServerStaticEntry) error {
// ParseConfig takes a JSON MySQL static config and converts to a validated map
func ParseConfig(jsonBytes []byte, config *map[string][]*AuthServerStaticEntry) error {
decoder := json.NewDecoder(bytes.NewReader(jsonBytes))
decoder.DisallowUnknownFields()
if err := decoder.Decode(config); err != nil {
Expand Down Expand Up @@ -336,7 +337,8 @@ func validateConfig(config map[string][]*AuthServerStaticEntry) error {
return nil
}

func matchSourceHost(remoteAddr net.Addr, targetSourceHost string) bool {
// MatchSourceHost validates host entry in auth configuration
func MatchSourceHost(remoteAddr net.Addr, targetSourceHost string) bool {
// Legacy support, there was not matcher defined default to true
if targetSourceHost == "" {
return true
Expand All @@ -352,11 +354,11 @@ func matchSourceHost(remoteAddr net.Addr, targetSourceHost string) bool {

// StaticUserData holds the username and groups
type StaticUserData struct {
username string
groups []string
Username string
Groups []string
}

// Get returns the wrapped username and groups
func (sud *StaticUserData) Get() *querypb.VTGateCallerID {
return &querypb.VTGateCallerID{Username: sud.username, Groups: sud.groups}
return &querypb.VTGateCallerID{Username: sud.Username, Groups: sud.Groups}
}
12 changes: 6 additions & 6 deletions go/mysql/auth_server_static_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func TestJsonConfigParser(t *testing.T) {
// works with legacy format
config := make(map[string][]*AuthServerStaticEntry)
jsonConfig := "{\"mysql_user\":{\"Password\":\"123\", \"UserData\":\"dummy\"}, \"mysql_user_2\": {\"Password\": \"123\", \"UserData\": \"mysql_user_2\"}}"
err := parseConfig([]byte(jsonConfig), &config)
err := ParseConfig([]byte(jsonConfig), &config)
if err != nil {
t.Fatalf("should not get an error, but got: %v", err)
}
Expand All @@ -53,7 +53,7 @@ func TestJsonConfigParser(t *testing.T) {
{"Password": "123", "UserData": "mysql_user_all"},
{"Password": "456", "UserData": "mysql_user_with_groups", "Groups": ["user_group"]}
]}`
err = parseConfig([]byte(jsonConfig), &config)
err = ParseConfig([]byte(jsonConfig), &config)
if err != nil {
t.Fatalf("should not get an error, but got: %v", err)
}
Expand All @@ -72,7 +72,7 @@ func TestJsonConfigParser(t *testing.T) {
jsonConfig = `{
"mysql_user": [{"Password": "123", "UserData": "mysql_user_all", "InvalidKey": "oops"}]
}`
err = parseConfig([]byte(jsonConfig), &config)
err = ParseConfig([]byte(jsonConfig), &config)
if err == nil {
t.Fatalf("Invalid config should have errored, but didn't")
}
Expand Down Expand Up @@ -109,18 +109,18 @@ func TestValidateHashGetter(t *testing.T) {
func TestHostMatcher(t *testing.T) {
ip := net.ParseIP("192.168.0.1")
addr := &net.TCPAddr{IP: ip, Port: 9999}
match := matchSourceHost(net.Addr(addr), "")
match := MatchSourceHost(net.Addr(addr), "")
if !match {
t.Fatalf("Should match any address when target is empty")
}

match = matchSourceHost(net.Addr(addr), "localhost")
match = MatchSourceHost(net.Addr(addr), "localhost")
if match {
t.Fatalf("Should not match address when target is localhost")
}

socket := &net.UnixAddr{Name: "unixSocket", Net: "1"}
match = matchSourceHost(net.Addr(socket), "localhost")
match = MatchSourceHost(net.Addr(socket), "localhost")
if !match {
t.Fatalf("Should match socket when target is localhost")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/

package mysql
package vault

import (
"crypto/subtle"
Expand All @@ -31,6 +31,7 @@ import (

vaultapi "github.com/aquarapid/vaultlib"

"vitess.io/vitess/go/mysql"
"vitess.io/vitess/go/vt/log"
)

Expand All @@ -48,12 +49,12 @@ var (

// AuthServerVault implements AuthServer with a config loaded from Vault.
type AuthServerVault struct {
methods []AuthMethod
methods []mysql.AuthMethod
mu sync.Mutex
// users, passwords and user data
// We use the same JSON format as for -mysql_auth_server_static
// Acts as a cache for the in-Vault data
entries map[string][]*AuthServerStaticEntry
entries map[string][]*mysql.AuthServerStaticEntry
vaultCacheExpireTicker *time.Ticker
vaultClient *vaultapi.Client
vaultPath string
Expand Down Expand Up @@ -81,7 +82,7 @@ func registerAuthServerVault(addr string, timeout time.Duration, caCertPath stri
if err != nil {
log.Exitf("%s", err)
}
RegisterAuthServer("vault", authServerVault)
mysql.RegisterAuthServer("vault", authServerVault)
}

func newAuthServerVault(addr string, timeout time.Duration, caCertPath string, path string, ttl time.Duration, tokenFilePath string, roleID string, secretIDPath string, roleMountPoint string) (*AuthServerVault, error) {
Expand Down Expand Up @@ -135,11 +136,11 @@ func newAuthServerVault(addr string, timeout time.Duration, caCertPath string, p
vaultClient: client,
vaultPath: path,
vaultTTL: ttl,
entries: make(map[string][]*AuthServerStaticEntry),
entries: make(map[string][]*mysql.AuthServerStaticEntry),
}

authMethodNative := NewMysqlNativeAuthMethod(a, a)
a.methods = []AuthMethod{authMethodNative}
authMethodNative := mysql.NewMysqlNativeAuthMethod(a, a)
a.methods = []mysql.AuthMethod{authMethodNative}

a.reloadVault()
a.installSignalHandlers()
Expand All @@ -148,14 +149,14 @@ func newAuthServerVault(addr string, timeout time.Duration, caCertPath string, p

// AuthMethods returns the list of registered auth methods
// implemented by this auth server.
func (a *AuthServerVault) AuthMethods() []AuthMethod {
func (a *AuthServerVault) AuthMethods() []mysql.AuthMethod {
return a.methods
}

// DefaultAuthMethodDescription returns MysqlNativePassword as the default
// authentication method for the auth server implementation.
func (a *AuthServerVault) DefaultAuthMethodDescription() AuthMethodDescription {
return MysqlNativePassword
func (a *AuthServerVault) DefaultAuthMethodDescription() mysql.AuthMethodDescription {
return mysql.MysqlNativePassword
}

// HandleUser is part of the Validator interface. We
Expand All @@ -165,34 +166,34 @@ func (a *AuthServerVault) HandleUser(user string) bool {
}

// UserEntryWithHash is called when mysql_native_password is used.
func (a *AuthServerVault) UserEntryWithHash(userCerts []*x509.Certificate, salt []byte, user string, authResponse []byte, remoteAddr net.Addr) (Getter, error) {
func (a *AuthServerVault) UserEntryWithHash(userCerts []*x509.Certificate, salt []byte, user string, authResponse []byte, remoteAddr net.Addr) (mysql.Getter, error) {
a.mu.Lock()
userEntries, ok := a.entries[user]
a.mu.Unlock()

if !ok {
return &StaticUserData{}, NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "Access denied for user '%v'", user)
return &mysql.StaticUserData{}, mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v'", user)
}

for _, entry := range userEntries {
if entry.MysqlNativePassword != "" {
hash, err := DecodeMysqlNativePasswordHex(entry.MysqlNativePassword)
hash, err := mysql.DecodeMysqlNativePasswordHex(entry.MysqlNativePassword)
if err != nil {
return &StaticUserData{entry.UserData, entry.Groups}, NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "Access denied for user '%v'", user)
return &mysql.StaticUserData{Username: entry.UserData, Groups: entry.Groups}, mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v'", user)
}
isPass := VerifyHashedMysqlNativePassword(authResponse, salt, hash)
if matchSourceHost(remoteAddr, entry.SourceHost) && isPass {
return &StaticUserData{entry.UserData, entry.Groups}, nil
isPass := mysql.VerifyHashedMysqlNativePassword(authResponse, salt, hash)
if mysql.MatchSourceHost(remoteAddr, entry.SourceHost) && isPass {
return &mysql.StaticUserData{Username: entry.UserData, Groups: entry.Groups}, nil
}
} else {
computedAuthResponse := ScrambleMysqlNativePassword(salt, []byte(entry.Password))
computedAuthResponse := mysql.ScrambleMysqlNativePassword(salt, []byte(entry.Password))
// Validate the password.
if matchSourceHost(remoteAddr, entry.SourceHost) && subtle.ConstantTimeCompare(authResponse, computedAuthResponse) == 1 {
return &StaticUserData{entry.UserData, entry.Groups}, nil
if mysql.MatchSourceHost(remoteAddr, entry.SourceHost) && subtle.ConstantTimeCompare(authResponse, computedAuthResponse) == 1 {
return &mysql.StaticUserData{Username: entry.UserData, Groups: entry.Groups}, nil
}
}
}
return &StaticUserData{}, NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "Access denied for user '%v'", user)
return &mysql.StaticUserData{}, mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v'", user)
}

func (a *AuthServerVault) setTTLTicker(ttl time.Duration) {
Expand Down Expand Up @@ -225,8 +226,8 @@ func (a *AuthServerVault) reloadVault() error {
return fmt.Errorf("Empty vtgate credentials retrieved from Vault server")
}

entries := make(map[string][]*AuthServerStaticEntry)
if err := parseConfig(secret.JSONSecret, &entries); err != nil {
entries := make(map[string][]*mysql.AuthServerStaticEntry)
if err := mysql.ParseConfig(secret.JSONSecret, &entries); err != nil {
return fmt.Errorf("Error parsing vtgate Vault auth server config: %v", err)
}
if len(entries) == 0 {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ see the license for the specific language governing permissions and
limitations under the license.
*/

package mysql
package vault

import (
"testing"
Expand Down
11 changes: 11 additions & 0 deletions go/vt/topo/topoproto/tablet.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,3 +253,14 @@ func TabletDbName(tablet *topodatapb.Tablet) string {
func TabletIsAssigned(tablet *topodatapb.Tablet) bool {
return tablet != nil && tablet.Keyspace != "" && tablet.Shard != ""
}

// IsServingType returns true if the tablet type is one that should be serving to be healthy, or false if the tablet type
// should not be serving in it's healthy state.
func IsServingType(tabletType topodatapb.TabletType) bool {
switch tabletType {
case topodatapb.TabletType_PRIMARY, topodatapb.TabletType_REPLICA, topodatapb.TabletType_BATCH, topodatapb.TabletType_EXPERIMENTAL:
return true
default:
return false
}
}
14 changes: 2 additions & 12 deletions go/vt/vttablet/tabletserver/tabletserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ import (
"vitess.io/vitess/go/vt/srvtopo"
"vitess.io/vitess/go/vt/tableacl"
"vitess.io/vitess/go/vt/topo"
"vitess.io/vitess/go/vt/topo/topoproto"
"vitess.io/vitess/go/vt/vterrors"
"vitess.io/vitess/go/vt/vttablet/onlineddl"
"vitess.io/vitess/go/vt/vttablet/queryservice"
Expand Down Expand Up @@ -370,7 +371,7 @@ func (tsv *TabletServer) StopService() {
// connect to the database and serving traffic), or an error explaining
// the unhealthiness otherwise.
func (tsv *TabletServer) IsHealthy() error {
if IsServingType(tsv.sm.Target().TabletType) {
if topoproto.IsServingType(tsv.sm.Target().TabletType) {
_, err := tsv.Execute(
tabletenv.LocalContext(),
nil,
Expand All @@ -385,17 +386,6 @@ func (tsv *TabletServer) IsHealthy() error {
return nil
}

// IsServingType returns true if the tablet type is one that should be serving to be healthy, or false if the tablet type
// should not be serving in it's healthy state.
func IsServingType(tabletType topodatapb.TabletType) bool {
switch tabletType {
case topodatapb.TabletType_PRIMARY, topodatapb.TabletType_REPLICA, topodatapb.TabletType_BATCH, topodatapb.TabletType_EXPERIMENTAL:
return true
default:
return false
}
}

// ReloadSchema reloads the schema.
func (tsv *TabletServer) ReloadSchema(ctx context.Context) error {
return tsv.se.Reload(ctx)
Expand Down

0 comments on commit 5fee4d0

Please sign in to comment.