diff --git a/cluster/cluster.go b/cluster/cluster.go index 30a4927dc..c74b83e48 100644 --- a/cluster/cluster.go +++ b/cluster/cluster.go @@ -28,6 +28,8 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/labels" "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" apiv1alpha1 "github.com/radondb/radondb-mysql-kubernetes/api/v1alpha1" @@ -331,3 +333,28 @@ func sizeToBytes(s string) (uint64, error) { } return 0, fmt.Errorf("'%s' format error, must be a positive integer with a unit of measurement like K, M or G", s) } + +// IsMysqlClusterKind for the given kind checks if CRD kind is for MysqlCluster CRD. +func IsMysqlClusterKind(kind string) bool { + switch kind { + case "MysqlCluster", "mysqlcluster", "mysqlclusters": + return true + } + return false +} + +// GetClusterKey returns the MysqlUser's MySQLCluster key. +func (c *Cluster) GetClusterKey() client.ObjectKey { + return client.ObjectKey{ + Name: c.Name, + Namespace: c.Namespace, + } +} + +// GetKey return the user key. Usually used for logging or for runtime.Client.Get as key. +func (c *Cluster) GetKey() client.ObjectKey { + return types.NamespacedName{ + Namespace: c.Namespace, + Name: c.Name, + } +} diff --git a/cluster/syncer/statefulset.go b/cluster/syncer/statefulset.go index 23a98f171..e4c45e03d 100644 --- a/cluster/syncer/statefulset.go +++ b/cluster/syncer/statefulset.go @@ -63,10 +63,13 @@ type StatefulSetSyncer struct { // Secret resourceVersion. sctRev string + + // mysql query runner + internal.SQLRunnerFactory } // NewStatefulSetSyncer returns a pointer to StatefulSetSyncer. -func NewStatefulSetSyncer(cli client.Client, c *cluster.Cluster, cmRev, sctRev string) *StatefulSetSyncer { +func NewStatefulSetSyncer(cli client.Client, c *cluster.Cluster, cmRev, sctRev string, sqlRunnerFactory internal.SQLRunnerFactory) *StatefulSetSyncer { return &StatefulSetSyncer{ Cluster: c, cli: cli, @@ -80,8 +83,9 @@ func NewStatefulSetSyncer(cli client.Client, c *cluster.Cluster, cmRev, sctRev s Namespace: c.Namespace, }, }, - cmRev: cmRev, - sctRev: sctRev, + cmRev: cmRev, + sctRev: sctRev, + SQLRunnerFactory: sqlRunnerFactory, } } @@ -258,6 +262,12 @@ func (s *StatefulSetSyncer) updatePod(ctx context.Context) error { // 5. Check followerHost current role. // 6. If followerHost is not leader, switch it to leader through xenon. func (s *StatefulSetSyncer) preUpdate(ctx context.Context, leader, follower string) error { + sqlRunner, closeConn, err := s.SQLRunnerFactory(internal.NewConfigFromClusterKey(s.cli, s.Cluster.GetClusterKey(), utils.OperatorUser, string(utils.LeaderHost))) + if err != nil { + return err + } + defer closeConn() + // Status.Replicas indicate the number of Pod has been created. // So sfs.Spec.Replicas is 2, May be sfs.Status.Replicas maybe are 3, 5 , // because it do not update the pods, so it is still the last status. @@ -272,7 +282,6 @@ func (s *StatefulSetSyncer) preUpdate(ctx context.Context, leader, follower stri defer utils.RemoveUpdateFile() sctName := s.GetNameForResource(utils.Secret) svcName := s.GetNameForResource(utils.HeadlessSVC) - port := utils.MysqlPort nameSpace := s.Namespace // Get secrets. @@ -286,36 +295,21 @@ func (s *StatefulSetSyncer) preUpdate(ctx context.Context, leader, follower stri ); err != nil { return fmt.Errorf("failed to get the secret: %s", sctName) } - user, ok := secret.Data["operator-user"] - if !ok { - return fmt.Errorf("failed to get the user: %s", user) - } - password, ok := secret.Data["operator-password"] - if !ok { - return fmt.Errorf("failed to get the password: %s", password) - } + rootPasswd, ok := secret.Data["root-password"] if !ok { return fmt.Errorf("failed to get the root password: %s", rootPasswd) } - leaderHost := fmt.Sprintf("%s.%s.%s", leader, svcName, nameSpace) - leaderRunner, err := internal.NewSQLRunner(utils.BytesToString(user), utils.BytesToString(password), leaderHost, port) - if err != nil { - log.Error(err, "failed to connect the mysql", "node", leader) - return err - } - defer leaderRunner.Close() - if err = retry(time.Second*2, time.Duration(waitLimit)*time.Second, func() (bool, error) { // Set leader read only. - if err = leaderRunner.RunQuery("SET GLOBAL super_read_only=on;"); err != nil { + if err = sqlRunner.RunQuery("SET GLOBAL super_read_only=on;"); err != nil { log.Error(err, "failed to set leader read only", "node", leader) return false, err } // Make sure the master has sent all binlog to slave. - success, err := leaderRunner.CheckProcesslist() + success, err := sqlRunner.CheckProcesslist() if err != nil { return false, err } diff --git a/cluster/syncer/status.go b/cluster/syncer/status.go index ee20a3172..f7f6aeb6a 100644 --- a/cluster/syncer/status.go +++ b/cluster/syncer/status.go @@ -48,13 +48,17 @@ type StatusSyncer struct { *cluster.Cluster cli client.Client + + // mysql query runner + internal.SQLRunnerFactory } // NewStatusSyncer returns a pointer to StatusSyncer. -func NewStatusSyncer(c *cluster.Cluster, cli client.Client) *StatusSyncer { +func NewStatusSyncer(c *cluster.Cluster, cli client.Client, sqlRunnerFactory internal.SQLRunnerFactory) *StatusSyncer { return &StatusSyncer{ - Cluster: c, - cli: cli, + Cluster: c, + cli: cli, + SQLRunnerFactory: sqlRunnerFactory, } } @@ -144,7 +148,6 @@ func (s *StatusSyncer) Sync(ctx context.Context) (syncer.SyncResult, error) { func (s *StatusSyncer) updateNodeStatus(ctx context.Context, cli client.Client, pods []corev1.Pod) error { sctName := s.GetNameForResource(utils.Secret) svcName := s.GetNameForResource(utils.HeadlessSVC) - port := utils.MysqlPort nameSpace := s.Namespace secret := &corev1.Secret{} @@ -158,14 +161,7 @@ func (s *StatusSyncer) updateNodeStatus(ctx context.Context, cli client.Client, log.V(1).Info("secret not found", "name", sctName) return nil } - user, ok := secret.Data["operator-user"] - if !ok { - return fmt.Errorf("failed to get the user: %s", user) - } - password, ok := secret.Data["operator-password"] - if !ok { - return fmt.Errorf("failed to get the password: %s", password) - } + rootPasswd, ok := secret.Data["root-password"] if !ok { return fmt.Errorf("failed to get the root password: %s", rootPasswd) @@ -187,18 +183,19 @@ func (s *StatusSyncer) updateNodeStatus(ctx context.Context, cli client.Client, s.updateNodeCondition(node, int(apiv1alpha1.IndexLeader), isLeader) isLagged, isReplicating, isReadOnly := corev1.ConditionUnknown, corev1.ConditionUnknown, corev1.ConditionUnknown - runner, err := internal.NewSQLRunner(utils.BytesToString(user), utils.BytesToString(password), host, port) + sqlRunner, closeConn, err := s.SQLRunnerFactory(internal.NewConfigFromClusterKey(s.cli, s.Cluster.GetClusterKey(), utils.OperatorUser, host)) + defer closeConn() if err != nil { log.Error(err, "failed to connect the mysql", "node", node.Name) node.Message = err.Error() } else { - isLagged, isReplicating, err = runner.CheckSlaveStatusWithRetry(checkNodeStatusRetry) + isLagged, isReplicating, err = sqlRunner.CheckSlaveStatusWithRetry(checkNodeStatusRetry) if err != nil { log.Error(err, "failed to check slave status", "node", node.Name) node.Message = err.Error() } - isReadOnly, err = runner.CheckReadOnly() + isReadOnly, err = sqlRunner.CheckReadOnly() if err != nil { log.Error(err, "failed to check read only", "node", node.Name) node.Message = err.Error() @@ -208,15 +205,11 @@ func (s *StatusSyncer) updateNodeStatus(ctx context.Context, cli client.Client, isLeader == corev1.ConditionTrue && isReadOnly != corev1.ConditionFalse { log.V(1).Info("try to correct the leader writeable", "node", node.Name) - runner.RunQuery("SET GLOBAL read_only=off") - runner.RunQuery("SET GLOBAL super_read_only=off") + sqlRunner.RunQuery("SET GLOBAL read_only=off") + sqlRunner.RunQuery("SET GLOBAL super_read_only=off") } } - if runner != nil { - runner.Close() - } - // update apiv1alpha1.NodeConditionLagged. s.updateNodeCondition(node, int(apiv1alpha1.IndexLagged), isLagged) // update apiv1alpha1.NodeConditionReplicating. diff --git a/cmd/manager/main.go b/cmd/manager/main.go index ee5241c3e..4bee7e5c7 100644 --- a/cmd/manager/main.go +++ b/cmd/manager/main.go @@ -33,6 +33,7 @@ import ( mysqlv1alpha1 "github.com/radondb/radondb-mysql-kubernetes/api/v1alpha1" "github.com/radondb/radondb-mysql-kubernetes/controllers" + "github.com/radondb/radondb-mysql-kubernetes/internal" //+kubebuilder:scaffold:imports ) @@ -79,17 +80,19 @@ func main() { } if err = (&controllers.ClusterReconciler{ - Client: mgr.GetClient(), - Scheme: mgr.GetScheme(), - Recorder: mgr.GetEventRecorderFor("controller.cluster"), + Client: mgr.GetClient(), + Scheme: mgr.GetScheme(), + Recorder: mgr.GetEventRecorderFor("controller.cluster"), + SQLRunnerFactory: internal.NewSQLRunner, }).SetupWithManager(mgr); err != nil { setupLog.Error(err, "unable to create controller", "controller", "Cluster") os.Exit(1) } if err = (&controllers.StatusReconciler{ - Client: mgr.GetClient(), - Scheme: mgr.GetScheme(), - Recorder: mgr.GetEventRecorderFor("controller.status"), + Client: mgr.GetClient(), + Scheme: mgr.GetScheme(), + Recorder: mgr.GetEventRecorderFor("controller.status"), + SQLRunnerFactory: internal.NewSQLRunner, }).SetupWithManager(mgr); err != nil { setupLog.Error(err, "unable to create controller", "controller", "Status") os.Exit(1) diff --git a/controllers/cluster_controller.go b/controllers/cluster_controller.go index fba466acf..0d033ffab 100644 --- a/controllers/cluster_controller.go +++ b/controllers/cluster_controller.go @@ -35,6 +35,7 @@ import ( apiv1alpha1 "github.com/radondb/radondb-mysql-kubernetes/api/v1alpha1" "github.com/radondb/radondb-mysql-kubernetes/cluster" clustersyncer "github.com/radondb/radondb-mysql-kubernetes/cluster/syncer" + "github.com/radondb/radondb-mysql-kubernetes/internal" ) // ClusterReconciler reconciles a Cluster object @@ -42,6 +43,9 @@ type ClusterReconciler struct { client.Client Scheme *runtime.Scheme Recorder record.EventRecorder + + // mysql query runner + internal.SQLRunnerFactory } // +kubebuilder:rbac:groups=mysql.radondb.com,resources=clusters,verbs=get;list;watch;create;update;patch;delete @@ -114,7 +118,7 @@ func (r *ClusterReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ct clustersyncer.NewHeadlessSVCSyncer(r.Client, instance), clustersyncer.NewLeaderSVCSyncer(r.Client, instance), clustersyncer.NewFollowerSVCSyncer(r.Client, instance), - clustersyncer.NewStatefulSetSyncer(r.Client, instance, cmRev, sctRev), + clustersyncer.NewStatefulSetSyncer(r.Client, instance, cmRev, sctRev, r.SQLRunnerFactory), clustersyncer.NewPDBSyncer(r.Client, instance), } diff --git a/controllers/status_controller.go b/controllers/status_controller.go index f1da0d213..32ad5a813 100644 --- a/controllers/status_controller.go +++ b/controllers/status_controller.go @@ -40,6 +40,7 @@ import ( apiv1alpha1 "github.com/radondb/radondb-mysql-kubernetes/api/v1alpha1" "github.com/radondb/radondb-mysql-kubernetes/cluster" clustersyncer "github.com/radondb/radondb-mysql-kubernetes/cluster/syncer" + "github.com/radondb/radondb-mysql-kubernetes/internal" ) // reconcileTimePeriod represents the time in which a cluster should be reconciled @@ -50,6 +51,9 @@ type StatusReconciler struct { client.Client Scheme *runtime.Scheme Recorder record.EventRecorder + + // mysql query runner + internal.SQLRunnerFactory } // Reconcile is part of the main kubernetes reconciliation loop which aims to @@ -88,7 +92,7 @@ func (r *StatusReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctr } }() - statusSyncer := clustersyncer.NewStatusSyncer(instance, r.Client) + statusSyncer := clustersyncer.NewStatusSyncer(instance, r.Client, r.SQLRunnerFactory) if err := syncer.Sync(ctx, statusSyncer, r.Recorder); err != nil { return ctrl.Result{}, err } diff --git a/internal/query.go b/internal/query.go new file mode 100644 index 000000000..b43e5f76d --- /dev/null +++ b/internal/query.go @@ -0,0 +1,79 @@ +/* +Copyright 2021 RadonDB. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package internal + +import "strings" + +// Query contains a escaped query string with variables marked with a question mark (?) and a slice +// of positional arguments. +type Query struct { + escapedQuery string + args []interface{} +} + +// String representation of the query. +func (q *Query) String() string { + return q.escapedQuery +} + +// Args representation of the query. +func (q *Query) Args() []interface{} { + return q.args +} + +// NewQuery returns a new Query object. +func NewQuery(q string, args ...interface{}) Query { + if q == "" { + panic("unexpected empty query") + } + + if !strings.HasSuffix(q, ";") { + q += ";" + } + + return Query{ + escapedQuery: q, + args: args, + } +} + +// ConcatenateQueries concatenates the provided queries into a single query. +func ConcatenateQueries(queries ...Query) Query { + args := []interface{}{} + query := "" + + for _, pq := range queries { + if query != "" { + if !strings.HasSuffix(query, "\n") { + query += "\n" + } + } + + query += pq.escapedQuery + args = append(args, pq.args...) + } + + return NewQuery(query, args...) +} + +// BuildAtomicQuery concatenates the provided queries into a single query wrapped in a BEGIN COMMIT block. +func BuildAtomicQuery(queries ...Query) Query { + queries = append([]Query{NewQuery("BEGIN")}, queries...) + queries = append(queries, NewQuery("COMMIT")) + + return ConcatenateQueries(queries...) +} diff --git a/internal/sql_runner.go b/internal/sql_runner.go index 074824699..5bce0a9a0 100644 --- a/internal/sql_runner.go +++ b/internal/sql_runner.go @@ -17,7 +17,9 @@ limitations under the License. package internal import ( + "context" "database/sql" + "errors" "fmt" "strconv" "strings" @@ -25,7 +27,11 @@ import ( _ "github.com/go-sql-driver/mysql" corev1 "k8s.io/api/core/v1" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/log" + apiv1alpha1 "github.com/radondb/radondb-mysql-kubernetes/api/v1alpha1" + mysqlcluster "github.com/radondb/radondb-mysql-kubernetes/cluster" "github.com/radondb/radondb-mysql-kubernetes/utils" ) @@ -38,30 +44,126 @@ var errorConnectionStates = []string{ "waiting to reconnect after a failed master event read", } -// SQLRunner is a runner for run the sql. -type SQLRunner struct { - db *sql.DB +var internalLog = log.Log.WithName("mysql-internal") + +// Config is used to connect to a MysqlCluster. +type Config struct { + User string + Password string + Host string + Port int32 } -// NewSQLRunner return a pointer to SQLRunner. -func NewSQLRunner(user, password, host string, port int) (*SQLRunner, error) { - dataSourceName := fmt.Sprintf("%s:%s@tcp(%s:%d)/?timeout=5s&interpolateParams=true&multiStatements=true", - user, password, host, port, - ) - db, err := sql.Open("mysql", dataSourceName) - if err != nil { +// NewConfigFromClusterKey returns a new Config based on a MySQLCluster key. +func NewConfigFromClusterKey(c client.Client, clusterKey client.ObjectKey, userName, host string) (*Config, error) { + cluster := &apiv1alpha1.Cluster{} + if err := c.Get(context.TODO(), clusterKey, cluster); err != nil { return nil, err } - if err = db.Ping(); err != nil { + secret := &corev1.Secret{} + secretKey := client.ObjectKey{Name: mysqlcluster.New(cluster).GetNameForResource(utils.Secret), Namespace: cluster.Namespace} + + if err := c.Get(context.TODO(), secretKey, secret); err != nil { return nil, err } - return &SQLRunner{db}, nil + if host == string(utils.LeaderHost) { + host = fmt.Sprintf("%s-leader.%s", cluster.Name, cluster.Namespace) + } + + switch userName { + case utils.OperatorUser: + return &Config{ + User: utils.OperatorUser, + Password: string(secret.Data["operator-password"]), + Host: host, + Port: 3306, + }, nil + case utils.RootUser: + return &Config{ + User: utils.RootUser, + Password: string(secret.Data["root-password"]), + Host: host, + Port: 3306, + }, nil + default: + return nil, fmt.Errorf("failed to get the configuration of sqlrunner") + } + +} + +// GetMysqlDSN returns a data source name. +func (c *Config) GetMysqlDSN() string { + return fmt.Sprintf("%s:%s@tcp(%s:%d)/?timeout=5s&multiStatements=true&interpolateParams=true", + c.User, c.Password, c.Host, c.Port, + ) +} + +type sqlRunner struct { + db *sql.DB +} + +// SQLRunner interface is a subset of mysql.DB. +type SQLRunner interface { + RunQuery(query string, args ...interface{}) error + ClusterOperation + UserOperation +} + +// ClusterOperation interface contains sql operations for cluster. +type ClusterOperation interface { + CheckSlaveStatusWithRetry(retry uint32) (isLagged, isReplicating corev1.ConditionStatus, err error) + CheckReadOnly() (corev1.ConditionStatus, error) + GetGlobalVariable(param string, val interface{}) error + CheckProcesslist() (bool, error) +} + +// UserOperation interface contains sql operations for mysql user. +type UserOperation interface { + CreateUserIfNotExists(user, pass string, allowedHosts []string, permissions []apiv1alpha1.UserPermission) error + DropUser(user, host string) error +} + +// SQLRunnerFactory a function that generates a new SQLRunner. +type SQLRunnerFactory func(cfg *Config, errs ...error) (SQLRunner, func(), error) + +// NewSQLRunner opens a connections using the given DSN. +func NewSQLRunner(cfg *Config, errs ...error) (SQLRunner, func(), error) { + var db *sql.DB + var close func() + + // Make this factory accept a functions that tries to generate a config. + if len(errs) > 0 && errs[0] != nil { + return nil, close, errs[0] + } + + db, err := sql.Open("mysql", cfg.GetMysqlDSN()) + if err != nil { + return nil, close, err + } + + // Close connection function. + close = func() { + if cErr := db.Close(); cErr != nil { + internalLog.Error(cErr, "failed closing the database connection") + } + } + + return &sqlRunner{db: db}, close, nil +} + +// RunQuery used to run the query with args. +func (s sqlRunner) RunQuery(query string, args ...interface{}) error { + if _, err := s.db.Exec(query, args...); err != nil { + return err + } + + return nil } // CheckSlaveStatusWithRetry check the slave status with retry time. -func (s *SQLRunner) CheckSlaveStatusWithRetry(retry uint32) (isLagged, isReplicating corev1.ConditionStatus, err error) { +func (s sqlRunner) CheckSlaveStatusWithRetry(retry uint32) (isLagged, isReplicating corev1.ConditionStatus, err error) { for { if retry == 0 { break @@ -79,7 +181,7 @@ func (s *SQLRunner) CheckSlaveStatusWithRetry(retry uint32) (isLagged, isReplica } // checkSlaveStatus check the slave status. -func (s *SQLRunner) checkSlaveStatus() (isLagged, isReplicating corev1.ConditionStatus, err error) { +func (s sqlRunner) checkSlaveStatus() (isLagged, isReplicating corev1.ConditionStatus, err error) { var rows *sql.Rows isLagged, isReplicating = corev1.ConditionUnknown, corev1.ConditionUnknown rows, err = s.db.Query("show slave status;") @@ -144,7 +246,7 @@ func (s *SQLRunner) checkSlaveStatus() (isLagged, isReplicating corev1.Condition } // CheckReadOnly check whether the mysql is read only. -func (s *SQLRunner) CheckReadOnly() (corev1.ConditionStatus, error) { +func (s sqlRunner) CheckReadOnly() (corev1.ConditionStatus, error) { var readOnly uint8 if err := s.GetGlobalVariable("read_only", &readOnly); err != nil { return corev1.ConditionUnknown, err @@ -157,22 +259,13 @@ func (s *SQLRunner) CheckReadOnly() (corev1.ConditionStatus, error) { return corev1.ConditionTrue, nil } -// RunQuery used to run the query with args. -func (s *SQLRunner) RunQuery(query string, args ...interface{}) error { - if _, err := s.db.Exec(query, args...); err != nil { - return err - } - - return nil -} - // GetGlobalVariable used to get the global variable by param. -func (s *SQLRunner) GetGlobalVariable(param string, val interface{}) error { +func (s sqlRunner) GetGlobalVariable(param string, val interface{}) error { query := fmt.Sprintf("select @@global.%s", param) return s.db.QueryRow(query).Scan(val) } -func (s *SQLRunner) CheckProcesslist() (bool, error) { +func (s sqlRunner) CheckProcesslist() (bool, error) { var rows *sql.Rows rows, err := s.db.Query("show processlist;") if err != nil { @@ -205,11 +298,6 @@ func (s *SQLRunner) CheckProcesslist() (bool, error) { return false, nil } -// Close closes the database and prevents new queries from starting. -func (sr *SQLRunner) Close() error { - return sr.db.Close() -} - // columnValue get the column value. func columnValue(scanArgs []interface{}, slaveCols []string, colName string) string { columnIndex := -1 @@ -226,3 +314,139 @@ func columnValue(scanArgs []interface{}, slaveCols []string, colName string) str return string(*scanArgs[columnIndex].(*sql.RawBytes)) } + +// CreateUserIfNotExists creates a user if it doesn't already exist and it gives it the specified permissions. +func (s sqlRunner) CreateUserIfNotExists( + user, pass string, allowedHosts []string, permissions []apiv1alpha1.UserPermission, +) error { + + // Throw error if there are no allowed hosts. + if len(allowedHosts) == 0 { + return errors.New("no allowedHosts specified") + } + + queries := []Query{ + getCreateUserQuery(user, pass, allowedHosts), + // todo: getAlterUserQuery + } + + if len(permissions) > 0 { + queries = append(queries, permissionsToQuery(permissions, user, allowedHosts)) + } + + query := BuildAtomicQuery(queries...) + + if err := s.RunQuery(query.String(), query.args...); err != nil { + return fmt.Errorf("failed to configure user (user/pass/access), err: %s", err) + } + + return nil +} + +func getCreateUserQuery(user, pwd string, allowedHosts []string) Query { + idsTmpl, idsArgs := getUsersIdentification(user, &pwd, allowedHosts) + + return NewQuery(fmt.Sprintf("CREATE USER IF NOT EXISTS%s", idsTmpl), idsArgs...) +} + +func getUsersIdentification(user string, pwd *string, allowedHosts []string) (ids string, args []interface{}) { + for i, host := range allowedHosts { + // Add comma if more than one allowed hosts are used. + if i > 0 { + ids += "," + } + + if pwd != nil { + ids += " ?@? IDENTIFIED BY ?" + args = append(args, user, host, *pwd) + } else { + ids += " ?@?" + args = append(args, user, host) + } + } + + return ids, args +} + +// DropUser removes a MySQL user if it exists, along with its privileges. +func (s sqlRunner) DropUser(user, host string) error { + query := NewQuery("DROP USER IF EXISTS ?@?;", user, host) + + if err := s.RunQuery(query.String(), query.args...); err != nil { + return fmt.Errorf("failed to delete user, err: %s", err) + } + + return nil +} + +func permissionsToQuery(permissions []apiv1alpha1.UserPermission, user string, allowedHosts []string) Query { + permQueries := []Query{} + + for _, perm := range permissions { + // If you wish to grant permissions on all tables, you should explicitly use "*". + for _, table := range perm.Tables { + args := []interface{}{} + + escPerms := []string{} + for _, perm := range perm.Privileges { + escPerms = append(escPerms, Escape(perm)) + } + + schemaTable := fmt.Sprintf("%s.%s", escapeID(perm.Database), escapeID(table)) + + // Build GRANT query. + idsTmpl, idsArgs := getUsersIdentification(user, nil, allowedHosts) + + query := "GRANT " + strings.Join(escPerms, ", ") + " ON " + schemaTable + " TO" + idsTmpl + args = append(args, idsArgs...) + + permQueries = append(permQueries, NewQuery(query, args...)) + } + } + + return ConcatenateQueries(permQueries...) +} + +func escapeID(id string) string { + if id == "*" { + return id + } + + // don't allow using ` in id name + id = strings.ReplaceAll(id, "`", "") + + return fmt.Sprintf("`%s`", id) +} + +// Escape escapes a string. +func Escape(sql string) string { + dest := make([]byte, 0, 2*len(sql)) + var escape byte + for i := 0; i < len(sql); i++ { + escape = 0 + switch sql[i] { + case 0: /* Must be escaped for 'mysql' */ + escape = '0' + case '\n': /* Must be escaped for logs */ + escape = 'n' + case '\r': + escape = 'r' + case '\\': + escape = '\\' + case '\'': + escape = '\'' + case '"': /* Better safe than sorry */ + escape = '"' + case '\032': /* This gives problems on Win32 */ + escape = 'Z' + } + + if escape != 0 { + dest = append(dest, '\\', escape) + } else { + dest = append(dest, sql[i]) + } + } + + return string(dest) +} diff --git a/utils/constants.go b/utils/constants.go index df3fbb2fb..3b077b37f 100644 --- a/utils/constants.go +++ b/utils/constants.go @@ -66,6 +66,8 @@ const ( MetricsUser = "qc_metrics" // The MySQL user used for operator to connect to the mysql node for configuration. OperatorUser = "qc_operator" + // The name of the MySQL root user. + RootUser = "root" // xtrabackup http server user BackupUser = "sys_backup" @@ -124,6 +126,10 @@ const ( ServiceAccount ResourceName = "service-account" // PodDisruptionBudget is the name of pod disruption budget for the statefulset. PodDisruptionBudget ResourceName = "pdb" + // UserControllerName is the name of UserController. + UserControllerName ResourceName = "controller.user" + // LeaderNode is the alias for leader`s host. + LeaderHost ResourceName = "leader-host" ) // JobType