Skip to content

Commit

Permalink
*: add user management related operations.
Browse files Browse the repository at this point in the history
  • Loading branch information
runkecheng committed Sep 27, 2021
1 parent f605392 commit 548c01e
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 0 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ require (
github.com/blang/semver v3.5.1+incompatible
github.com/go-ini/ini v1.62.0
github.com/go-sql-driver/mysql v1.6.0
github.com/go-test/deep v1.0.7 // indirect
github.com/iancoleman/strcase v0.0.0-20190422225806-e506e3ef7365
github.com/imdario/mergo v0.3.12
github.com/onsi/ginkgo v1.16.4
Expand Down
27 changes: 27 additions & 0 deletions internal/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,30 @@ func NewQuery(q string, args ...interface{}) Query {
args: args,
}
}

// ConcatenateQueries concatenates the provided queries into a single query.
func ConcatenateQueries(queries ...Query) Query {
args := []interface{}{}
query := ""

for _, pq := range queries {
if query != "" {
if !strings.HasSuffix(query, "\n") {
query += "\n"
}
}

query += pq.escapedQuery
args = append(args, pq.args...)
}

return NewQuery(query, args...)
}

// BuildAtomicQuery concatenates the provided queries into a single query wrapped in a BEGIN COMMIT block.
func BuildAtomicQuery(queries ...Query) Query {
queries = append([]Query{NewQuery("BEGIN")}, queries...)
queries = append(queries, NewQuery("COMMIT"))

return ConcatenateQueries(queries...)
}
137 changes: 137 additions & 0 deletions internal/sql_runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package internal
import (
"context"
"database/sql"
"errors"
"fmt"
"strconv"
"strings"
Expand Down Expand Up @@ -323,3 +324,139 @@ func columnValue(scanArgs []interface{}, slaveCols []string, colName string) str

return string(*scanArgs[columnIndex].(*sql.RawBytes))
}

// CreateUserIfNotExists creates a user if it doesn't already exist and it gives it the specified permissions.
func (s sqlRunner) CreateUserIfNotExists(
user, pass string, allowedHosts []string, permissions []apiv1alpha1.UserPermission,
) error {

// Throw error if there are no allowed hosts.
if len(allowedHosts) == 0 {
return errors.New("no allowedHosts specified")
}

queries := []Query{
getCreateUserQuery(user, pass, allowedHosts),
// todo: getAlterUserQuery
}

if len(permissions) > 0 {
queries = append(queries, permissionsToQuery(permissions, user, allowedHosts))
}

query := BuildAtomicQuery(queries...)

if err := s.QueryExec(query); err != nil {
return fmt.Errorf("failed to configure user (user/pass/access), err: %s", err)
}

return nil
}

func getCreateUserQuery(user, pwd string, allowedHosts []string) Query {
idsTmpl, idsArgs := getUsersIdentification(user, &pwd, allowedHosts)

return NewQuery(fmt.Sprintf("CREATE USER IF NOT EXISTS%s", idsTmpl), idsArgs...)
}

func getUsersIdentification(user string, pwd *string, allowedHosts []string) (ids string, args []interface{}) {
for i, host := range allowedHosts {
// Add comma if more than one allowed hosts are used.
if i > 0 {
ids += ","
}

if pwd != nil {
ids += " ?@? IDENTIFIED BY ?"
args = append(args, user, host, *pwd)
} else {
ids += " ?@?"
args = append(args, user, host)
}
}

return ids, args
}

// DropUser removes a MySQL user if it exists, along with its privileges.
func (s sqlRunner) DropUser(user, host string) error {
query := NewQuery("DROP USER IF EXISTS ?@?;", user, host)

if err := s.QueryExec(query); err != nil {
return fmt.Errorf("failed to delete user, err: %s", err)
}

return nil
}

func permissionsToQuery(permissions []apiv1alpha1.UserPermission, user string, allowedHosts []string) Query {
permQueries := []Query{}

for _, perm := range permissions {
// If you wish to grant permissions on all tables, you should explicitly use "*".
for _, table := range perm.Tables {
args := []interface{}{}

escPerms := []string{}
for _, perm := range perm.Privileges {
escPerms = append(escPerms, Escape(perm))
}

schemaTable := fmt.Sprintf("%s.%s", escapeID(perm.Database), escapeID(table))

// Build GRANT query.
idsTmpl, idsArgs := getUsersIdentification(user, nil, allowedHosts)

query := "GRANT " + strings.Join(escPerms, ", ") + " ON " + schemaTable + " TO" + idsTmpl
args = append(args, idsArgs...)

permQueries = append(permQueries, NewQuery(query, args...))
}
}

return ConcatenateQueries(permQueries...)
}

func escapeID(id string) string {
if id == "*" {
return id
}

// don't allow using ` in id name
id = strings.ReplaceAll(id, "`", "")

return fmt.Sprintf("`%s`", id)
}

// Escape escapes a string.
func Escape(sql string) string {
dest := make([]byte, 0, 2*len(sql))
var escape byte
for i := 0; i < len(sql); i++ {
escape = 0
switch sql[i] {
case 0: /* Must be escaped for 'mysql' */
escape = '0'
case '\n': /* Must be escaped for logs */
escape = 'n'
case '\r':
escape = 'r'
case '\\':
escape = '\\'
case '\'':
escape = '\''
case '"': /* Better safe than sorry */
escape = '"'
case '\032': /* This gives problems on Win32 */
escape = 'Z'
}

if escape != 0 {
dest = append(dest, '\\', escape)
} else {
dest = append(dest, sql[i])
}
}

return string(dest)
}

0 comments on commit 548c01e

Please sign in to comment.