diff --git a/go/mysql/auth_server_static.go b/go/mysql/auth_server_static.go index c6e3004544c..290de7a2b87 100644 --- a/go/mysql/auth_server_static.go +++ b/go/mysql/auth_server_static.go @@ -148,8 +148,7 @@ func (a *AuthServerStatic) installSignalHandlers() { sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGHUP) go func() { - for { - <-sigChan + for range sigChan { a.loadConfigFromParams(*mysqlAuthServerStaticFile, "") } }() diff --git a/go/vt/dbconfigs/credentials.go b/go/vt/dbconfigs/credentials.go index ab73e9d5695..3062e7a2151 100644 --- a/go/vt/dbconfigs/credentials.go +++ b/go/vt/dbconfigs/credentials.go @@ -26,7 +26,10 @@ import ( "errors" "flag" "io/ioutil" + "os" + "os/signal" "sync" + "syscall" "vitess.io/vitess/go/mysql" "vitess.io/vitess/go/vt/log" @@ -37,7 +40,7 @@ var ( dbCredentialsServer = flag.String("db-credentials-server", "file", "db credentials server type (use 'file' for the file implementation)") // 'file' implementation flags - dbCredentialsFile = flag.String("db-credentials-file", "", "db credentials file") + dbCredentialsFile = flag.String("db-credentials-file", "", "db credentials file; send SIGHUP to reload this file") // ErrUnknownUser is returned by credential server when the // user doesn't exist @@ -126,4 +129,16 @@ func WithCredentials(cp *mysql.ConnParams) (*mysql.ConnParams, error) { func init() { AllCredentialsServers["file"] = &FileCredentialsServer{} + + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGHUP) + go func() { + for range sigChan { + if fcs, ok := AllCredentialsServers["file"].(*FileCredentialsServer); ok { + fcs.mu.Lock() + fcs.dbCredentials = nil + fcs.mu.Unlock() + } + } + }() } diff --git a/go/vt/dbconfigs/dbconfigs_test.go b/go/vt/dbconfigs/dbconfigs_test.go index 0f6e87fd4c5..86fbd757581 100644 --- a/go/vt/dbconfigs/dbconfigs_test.go +++ b/go/vt/dbconfigs/dbconfigs_test.go @@ -17,8 +17,13 @@ limitations under the License. package dbconfigs import ( + "fmt" + "io/ioutil" + "os" "reflect" + "syscall" "testing" + "time" "vitess.io/vitess/go/mysql" ) @@ -217,3 +222,47 @@ func TestCopy(t *testing.T) { t.Errorf("DBConfig: %v, want %v", got, want) } } + +func TestCredentialsFileHUP(t *testing.T) { + tmpFile, err := ioutil.TempFile("", "credentials.json") + if err != nil { + t.Fatalf("couldn't create temp file: %v", err) + } + defer os.Remove(tmpFile.Name()) + *dbCredentialsFile = tmpFile.Name() + *dbCredentialsServer = "file" + oldStr := "str1" + jsonConfig := fmt.Sprintf("{\"%s\": [\"%s\"]}", oldStr, oldStr) + if err := ioutil.WriteFile(tmpFile.Name(), []byte(jsonConfig), 0600); err != nil { + t.Fatalf("couldn't write temp file: %v", err) + } + cs := GetCredentialsServer() + _, pass, err := cs.GetUserAndPassword(oldStr) + if pass != oldStr { + t.Fatalf("%s's Password should still be '%s'", oldStr, oldStr) + } + hupTest(t, tmpFile, oldStr, "str2") + hupTest(t, tmpFile, "str2", "str3") // still handling the signal +} + +func hupTest(t *testing.T, tmpFile *os.File, oldStr, newStr string) { + cs := GetCredentialsServer() + jsonConfig := fmt.Sprintf("{\"%s\": [\"%s\"]}", newStr, newStr) + if err := ioutil.WriteFile(tmpFile.Name(), []byte(jsonConfig), 0600); err != nil { + t.Fatalf("couldn't overwrite temp file: %v", err) + } + _, pass, err := cs.GetUserAndPassword(oldStr) + if pass != oldStr { + t.Fatalf("%s's Password should still be '%s'", oldStr, oldStr) + } + syscall.Kill(syscall.Getpid(), syscall.SIGHUP) + time.Sleep(100 * time.Millisecond) // wait for signal handler + _, pass, err = cs.GetUserAndPassword(oldStr) + if err != ErrUnknownUser { + t.Fatalf("Should not have old %s after config reload", oldStr) + } + _, pass, err = cs.GetUserAndPassword(newStr) + if pass != newStr { + t.Fatalf("%s's Password should be '%s'", newStr, newStr) + } +} diff --git a/go/vt/vttablet/tabletserver/tabletserver.go b/go/vt/vttablet/tabletserver/tabletserver.go index e698969816e..3f25c1bdeea 100644 --- a/go/vt/vttablet/tabletserver/tabletserver.go +++ b/go/vt/vttablet/tabletserver/tabletserver.go @@ -374,8 +374,7 @@ func (tsv *TabletServer) InitACL(tableACLConfigFile string, enforceTableACLConfi sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGHUP) go func() { - for { - <-sigChan + for range sigChan { tsv.initACL(tableACLConfigFile, enforceTableACLConfig) } }()