Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid some pollution of CLI argument namespace across vtgate and vttablet #8931

Merged
merged 6 commits into from
Oct 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions go/cmd/vtgate/plugin_auth_vault.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ package main
// This plugin imports InitAuthServerVault to register the HashiCorp Vault implementation of AuthServer.

import (
"vitess.io/vitess/go/mysql"
"vitess.io/vitess/go/mysql/vault"
"vitess.io/vitess/go/vt/vtgate"
)

func init() {
vtgate.RegisterPluginInitializer(func() { mysql.InitAuthServerVault() })
vtgate.RegisterPluginInitializer(func() { vault.InitAuthServerVault() })
}
3 changes: 1 addition & 2 deletions go/cmd/vtgate/vtgate.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ import (
"vitess.io/vitess/go/vt/topo"
"vitess.io/vitess/go/vt/topo/topoproto"
"vitess.io/vitess/go/vt/vtgate"
"vitess.io/vitess/go/vt/vttablet/tabletserver"

topodatapb "vitess.io/vitess/go/vt/proto/topodata"
)
Expand Down Expand Up @@ -127,7 +126,7 @@ func main() {
log.Errorf("unknown tablet type: %v", ttStr)
continue
}
if tabletserver.IsServingType(tt) {
if topoproto.IsServingType(tt) {
tabletTypes = append(tabletTypes, tt)
}
}
Expand Down
4 changes: 2 additions & 2 deletions go/mysql/auth_server_clientcert.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func (asl *AuthServerClientCert) UserEntryWithPassword(userCerts []*x509.Certifi
}

return &StaticUserData{
username: commonName,
groups: userCerts[0].DNSNames,
Username: commonName,
Groups: userCerts[0].DNSNames,
}, nil
}
22 changes: 12 additions & 10 deletions go/mysql/auth_server_static.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ func (a *AuthServerStatic) UserEntryWithPassword(userCerts []*x509.Certificate,

for _, entry := range entries {
// Validate the password.
if matchSourceHost(remoteAddr, entry.SourceHost) && subtle.ConstantTimeCompare([]byte(password), []byte(entry.Password)) == 1 {
if MatchSourceHost(remoteAddr, entry.SourceHost) && subtle.ConstantTimeCompare([]byte(password), []byte(entry.Password)) == 1 {
return &StaticUserData{entry.UserData, entry.Groups}, nil
}
}
Expand All @@ -198,13 +198,13 @@ func (a *AuthServerStatic) UserEntryWithHash(userCerts []*x509.Certificate, salt
}

isPass := VerifyHashedMysqlNativePassword(authResponse, salt, hash)
if matchSourceHost(remoteAddr, entry.SourceHost) && isPass {
if MatchSourceHost(remoteAddr, entry.SourceHost) && isPass {
return &StaticUserData{entry.UserData, entry.Groups}, nil
}
} else {
computedAuthResponse := ScrambleMysqlNativePassword(salt, []byte(entry.Password))
// Validate the password.
if matchSourceHost(remoteAddr, entry.SourceHost) && subtle.ConstantTimeCompare(authResponse, computedAuthResponse) == 1 {
if MatchSourceHost(remoteAddr, entry.SourceHost) && subtle.ConstantTimeCompare(authResponse, computedAuthResponse) == 1 {
return &StaticUserData{entry.UserData, entry.Groups}, nil
}
}
Expand All @@ -227,7 +227,7 @@ func (a *AuthServerStatic) UserEntryWithCacheHash(userCerts []*x509.Certificate,
computedAuthResponse := ScrambleCachingSha2Password(salt, []byte(entry.Password))

// Validate the password.
if matchSourceHost(remoteAddr, entry.SourceHost) && subtle.ConstantTimeCompare(authResponse, computedAuthResponse) == 1 {
if MatchSourceHost(remoteAddr, entry.SourceHost) && subtle.ConstantTimeCompare(authResponse, computedAuthResponse) == 1 {
return &StaticUserData{entry.UserData, entry.Groups}, AuthAccepted, nil
}
}
Expand Down Expand Up @@ -257,7 +257,7 @@ func (a *AuthServerStatic) reload() {
}

entries := make(map[string][]*AuthServerStaticEntry)
if err := parseConfig(jsonBytes, &entries); err != nil {
if err := ParseConfig(jsonBytes, &entries); err != nil {
log.Errorf("Error parsing auth server config: %v", err)
return
}
Expand Down Expand Up @@ -300,7 +300,8 @@ func (a *AuthServerStatic) close() {
}
}

func parseConfig(jsonBytes []byte, config *map[string][]*AuthServerStaticEntry) error {
// ParseConfig takes a JSON MySQL static config and converts to a validated map
func ParseConfig(jsonBytes []byte, config *map[string][]*AuthServerStaticEntry) error {
decoder := json.NewDecoder(bytes.NewReader(jsonBytes))
decoder.DisallowUnknownFields()
if err := decoder.Decode(config); err != nil {
Expand Down Expand Up @@ -336,7 +337,8 @@ func validateConfig(config map[string][]*AuthServerStaticEntry) error {
return nil
}

func matchSourceHost(remoteAddr net.Addr, targetSourceHost string) bool {
// MatchSourceHost validates host entry in auth configuration
func MatchSourceHost(remoteAddr net.Addr, targetSourceHost string) bool {
// Legacy support, there was not matcher defined default to true
if targetSourceHost == "" {
return true
Expand All @@ -352,11 +354,11 @@ func matchSourceHost(remoteAddr net.Addr, targetSourceHost string) bool {

// StaticUserData holds the username and groups
type StaticUserData struct {
username string
groups []string
Username string
Groups []string
}

// Get returns the wrapped username and groups
func (sud *StaticUserData) Get() *querypb.VTGateCallerID {
return &querypb.VTGateCallerID{Username: sud.username, Groups: sud.groups}
return &querypb.VTGateCallerID{Username: sud.Username, Groups: sud.Groups}
}
12 changes: 6 additions & 6 deletions go/mysql/auth_server_static_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func TestJsonConfigParser(t *testing.T) {
// works with legacy format
config := make(map[string][]*AuthServerStaticEntry)
jsonConfig := "{\"mysql_user\":{\"Password\":\"123\", \"UserData\":\"dummy\"}, \"mysql_user_2\": {\"Password\": \"123\", \"UserData\": \"mysql_user_2\"}}"
err := parseConfig([]byte(jsonConfig), &config)
err := ParseConfig([]byte(jsonConfig), &config)
if err != nil {
t.Fatalf("should not get an error, but got: %v", err)
}
Expand All @@ -53,7 +53,7 @@ func TestJsonConfigParser(t *testing.T) {
{"Password": "123", "UserData": "mysql_user_all"},
{"Password": "456", "UserData": "mysql_user_with_groups", "Groups": ["user_group"]}
]}`
err = parseConfig([]byte(jsonConfig), &config)
err = ParseConfig([]byte(jsonConfig), &config)
if err != nil {
t.Fatalf("should not get an error, but got: %v", err)
}
Expand All @@ -72,7 +72,7 @@ func TestJsonConfigParser(t *testing.T) {
jsonConfig = `{
"mysql_user": [{"Password": "123", "UserData": "mysql_user_all", "InvalidKey": "oops"}]
}`
err = parseConfig([]byte(jsonConfig), &config)
err = ParseConfig([]byte(jsonConfig), &config)
if err == nil {
t.Fatalf("Invalid config should have errored, but didn't")
}
Expand Down Expand Up @@ -109,18 +109,18 @@ func TestValidateHashGetter(t *testing.T) {
func TestHostMatcher(t *testing.T) {
ip := net.ParseIP("192.168.0.1")
addr := &net.TCPAddr{IP: ip, Port: 9999}
match := matchSourceHost(net.Addr(addr), "")
match := MatchSourceHost(net.Addr(addr), "")
if !match {
t.Fatalf("Should match any address when target is empty")
}

match = matchSourceHost(net.Addr(addr), "localhost")
match = MatchSourceHost(net.Addr(addr), "localhost")
if match {
t.Fatalf("Should not match address when target is localhost")
}

socket := &net.UnixAddr{Name: "unixSocket", Net: "1"}
match = matchSourceHost(net.Addr(socket), "localhost")
match = MatchSourceHost(net.Addr(socket), "localhost")
if !match {
t.Fatalf("Should match socket when target is localhost")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/

package mysql
package vault

import (
"crypto/subtle"
Expand All @@ -31,6 +31,7 @@ import (

vaultapi "github.com/aquarapid/vaultlib"

"vitess.io/vitess/go/mysql"
"vitess.io/vitess/go/vt/log"
)

Expand All @@ -48,12 +49,12 @@ var (

// AuthServerVault implements AuthServer with a config loaded from Vault.
type AuthServerVault struct {
methods []AuthMethod
methods []mysql.AuthMethod
mu sync.Mutex
// users, passwords and user data
// We use the same JSON format as for -mysql_auth_server_static
// Acts as a cache for the in-Vault data
entries map[string][]*AuthServerStaticEntry
entries map[string][]*mysql.AuthServerStaticEntry
vaultCacheExpireTicker *time.Ticker
vaultClient *vaultapi.Client
vaultPath string
Expand Down Expand Up @@ -81,7 +82,7 @@ func registerAuthServerVault(addr string, timeout time.Duration, caCertPath stri
if err != nil {
log.Exitf("%s", err)
}
RegisterAuthServer("vault", authServerVault)
mysql.RegisterAuthServer("vault", authServerVault)
}

func newAuthServerVault(addr string, timeout time.Duration, caCertPath string, path string, ttl time.Duration, tokenFilePath string, roleID string, secretIDPath string, roleMountPoint string) (*AuthServerVault, error) {
Expand Down Expand Up @@ -135,11 +136,11 @@ func newAuthServerVault(addr string, timeout time.Duration, caCertPath string, p
vaultClient: client,
vaultPath: path,
vaultTTL: ttl,
entries: make(map[string][]*AuthServerStaticEntry),
entries: make(map[string][]*mysql.AuthServerStaticEntry),
}

authMethodNative := NewMysqlNativeAuthMethod(a, a)
a.methods = []AuthMethod{authMethodNative}
authMethodNative := mysql.NewMysqlNativeAuthMethod(a, a)
a.methods = []mysql.AuthMethod{authMethodNative}

a.reloadVault()
a.installSignalHandlers()
Expand All @@ -148,14 +149,14 @@ func newAuthServerVault(addr string, timeout time.Duration, caCertPath string, p

// AuthMethods returns the list of registered auth methods
// implemented by this auth server.
func (a *AuthServerVault) AuthMethods() []AuthMethod {
func (a *AuthServerVault) AuthMethods() []mysql.AuthMethod {
return a.methods
}

// DefaultAuthMethodDescription returns MysqlNativePassword as the default
// authentication method for the auth server implementation.
func (a *AuthServerVault) DefaultAuthMethodDescription() AuthMethodDescription {
return MysqlNativePassword
func (a *AuthServerVault) DefaultAuthMethodDescription() mysql.AuthMethodDescription {
return mysql.MysqlNativePassword
}

// HandleUser is part of the Validator interface. We
Expand All @@ -165,34 +166,34 @@ func (a *AuthServerVault) HandleUser(user string) bool {
}

// UserEntryWithHash is called when mysql_native_password is used.
func (a *AuthServerVault) UserEntryWithHash(userCerts []*x509.Certificate, salt []byte, user string, authResponse []byte, remoteAddr net.Addr) (Getter, error) {
func (a *AuthServerVault) UserEntryWithHash(userCerts []*x509.Certificate, salt []byte, user string, authResponse []byte, remoteAddr net.Addr) (mysql.Getter, error) {
a.mu.Lock()
userEntries, ok := a.entries[user]
a.mu.Unlock()

if !ok {
return &StaticUserData{}, NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "Access denied for user '%v'", user)
return &mysql.StaticUserData{}, mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v'", user)
}

for _, entry := range userEntries {
if entry.MysqlNativePassword != "" {
hash, err := DecodeMysqlNativePasswordHex(entry.MysqlNativePassword)
hash, err := mysql.DecodeMysqlNativePasswordHex(entry.MysqlNativePassword)
if err != nil {
return &StaticUserData{entry.UserData, entry.Groups}, NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "Access denied for user '%v'", user)
return &mysql.StaticUserData{Username: entry.UserData, Groups: entry.Groups}, mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v'", user)
}
isPass := VerifyHashedMysqlNativePassword(authResponse, salt, hash)
if matchSourceHost(remoteAddr, entry.SourceHost) && isPass {
return &StaticUserData{entry.UserData, entry.Groups}, nil
isPass := mysql.VerifyHashedMysqlNativePassword(authResponse, salt, hash)
if mysql.MatchSourceHost(remoteAddr, entry.SourceHost) && isPass {
return &mysql.StaticUserData{Username: entry.UserData, Groups: entry.Groups}, nil
}
} else {
computedAuthResponse := ScrambleMysqlNativePassword(salt, []byte(entry.Password))
computedAuthResponse := mysql.ScrambleMysqlNativePassword(salt, []byte(entry.Password))
// Validate the password.
if matchSourceHost(remoteAddr, entry.SourceHost) && subtle.ConstantTimeCompare(authResponse, computedAuthResponse) == 1 {
return &StaticUserData{entry.UserData, entry.Groups}, nil
if mysql.MatchSourceHost(remoteAddr, entry.SourceHost) && subtle.ConstantTimeCompare(authResponse, computedAuthResponse) == 1 {
return &mysql.StaticUserData{Username: entry.UserData, Groups: entry.Groups}, nil
}
}
}
return &StaticUserData{}, NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "Access denied for user '%v'", user)
return &mysql.StaticUserData{}, mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v'", user)
}

func (a *AuthServerVault) setTTLTicker(ttl time.Duration) {
Expand Down Expand Up @@ -225,8 +226,8 @@ func (a *AuthServerVault) reloadVault() error {
return fmt.Errorf("Empty vtgate credentials retrieved from Vault server")
}

entries := make(map[string][]*AuthServerStaticEntry)
if err := parseConfig(secret.JSONSecret, &entries); err != nil {
entries := make(map[string][]*mysql.AuthServerStaticEntry)
if err := mysql.ParseConfig(secret.JSONSecret, &entries); err != nil {
return fmt.Errorf("Error parsing vtgate Vault auth server config: %v", err)
}
if len(entries) == 0 {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ see the license for the specific language governing permissions and
limitations under the license.
*/

package mysql
package vault

import (
"testing"
Expand Down
11 changes: 11 additions & 0 deletions go/vt/topo/topoproto/tablet.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,3 +253,14 @@ func TabletDbName(tablet *topodatapb.Tablet) string {
func TabletIsAssigned(tablet *topodatapb.Tablet) bool {
return tablet != nil && tablet.Keyspace != "" && tablet.Shard != ""
}

// IsServingType returns true if the tablet type is one that should be serving to be healthy, or false if the tablet type
// should not be serving in it's healthy state.
func IsServingType(tabletType topodatapb.TabletType) bool {
switch tabletType {
case topodatapb.TabletType_PRIMARY, topodatapb.TabletType_REPLICA, topodatapb.TabletType_BATCH, topodatapb.TabletType_EXPERIMENTAL:
return true
default:
return false
}
}
14 changes: 2 additions & 12 deletions go/vt/vttablet/tabletserver/tabletserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ import (
"vitess.io/vitess/go/vt/srvtopo"
"vitess.io/vitess/go/vt/tableacl"
"vitess.io/vitess/go/vt/topo"
"vitess.io/vitess/go/vt/topo/topoproto"
"vitess.io/vitess/go/vt/vterrors"
"vitess.io/vitess/go/vt/vttablet/onlineddl"
"vitess.io/vitess/go/vt/vttablet/queryservice"
Expand Down Expand Up @@ -370,7 +371,7 @@ func (tsv *TabletServer) StopService() {
// connect to the database and serving traffic), or an error explaining
// the unhealthiness otherwise.
func (tsv *TabletServer) IsHealthy() error {
if IsServingType(tsv.sm.Target().TabletType) {
if topoproto.IsServingType(tsv.sm.Target().TabletType) {
_, err := tsv.Execute(
tabletenv.LocalContext(),
nil,
Expand All @@ -385,17 +386,6 @@ func (tsv *TabletServer) IsHealthy() error {
return nil
}

// IsServingType returns true if the tablet type is one that should be serving to be healthy, or false if the tablet type
// should not be serving in it's healthy state.
func IsServingType(tabletType topodatapb.TabletType) bool {
aquarapid marked this conversation as resolved.
Show resolved Hide resolved
switch tabletType {
case topodatapb.TabletType_PRIMARY, topodatapb.TabletType_REPLICA, topodatapb.TabletType_BATCH, topodatapb.TabletType_EXPERIMENTAL:
return true
default:
return false
}
}

// ReloadSchema reloads the schema.
func (tsv *TabletServer) ReloadSchema(ctx context.Context) error {
return tsv.se.Reload(ctx)
Expand Down