Skip to content

Commit

Permalink
tctl - Add --set flags for every trait (#14552)
Browse files Browse the repository at this point in the history
This change adds a --set flag to tctl users update for every trait where it wasn't implemented already (Windows logins, Kube users/groups, DB users/names, and AWS role ARNs).
  • Loading branch information
atburke authored Aug 1, 2022
1 parent 0a0eb9d commit 4653503
Show file tree
Hide file tree
Showing 3 changed files with 221 additions and 74 deletions.
78 changes: 33 additions & 45 deletions tool/tctl/common/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"testing"
"time"

"github.com/gravitational/kingpin"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
"gopkg.in/yaml.v2"
Expand Down Expand Up @@ -59,23 +60,12 @@ func withInsecure(insecure bool) optionsFunc {
}
}

func runResourceCommand(t *testing.T, fc *config.FileConfig, args []string, opts ...optionsFunc) (*bytes.Buffer, error) {
func getAuthClient(ctx context.Context, t *testing.T, fc *config.FileConfig, opts ...optionsFunc) auth.ClientI {
var options options
for _, v := range opts {
v(&options)
}
var stdoutBuff bytes.Buffer
command := &ResourceCommand{
stdout: &stdoutBuff,
}
cfg := service.MakeDefaultConfig()
cfg.CircuitBreakerConfig = breaker.NoopBreakerConfig()

app := utils.InitCLIParser("tctl", GlobalHelpString)
command.Initialize(app, cfg)

selectedCmd, err := app.Parse(args)
require.NoError(t, err)

var ccf GlobalCLIFlags
ccf.ConfigString = mustGetBase64EncFileConfig(t, fc)
Expand All @@ -88,56 +78,54 @@ func runResourceCommand(t *testing.T, fc *config.FileConfig, args []string, opts
clientConfig.TLS.RootCAs = options.CertPool
}

ctx := context.Background()
client, err := authclient.Connect(ctx, clientConfig)
require.NoError(t, err)

_, err = command.TryRun(ctx, selectedCmd, client)
if err != nil {
return nil, err
}
return &stdoutBuff, nil
return client
}

func runTokensCommand(t *testing.T, fc *config.FileConfig, args []string, opts ...optionsFunc) (*bytes.Buffer, error) {
var options options
for _, v := range opts {
v(&options)
}
type cliCommand interface {
Initialize(app *kingpin.Application, cfg *service.Config)
TryRun(ctx context.Context, cmd string, client auth.ClientI) (bool, error)
}

var stdoutBuff bytes.Buffer
command := &TokensCommand{
stdout: &stdoutBuff,
}
func runCommand(t *testing.T, fc *config.FileConfig, cmd cliCommand, args []string, opts ...optionsFunc) error {
cfg := service.MakeDefaultConfig()
cfg.CircuitBreakerConfig = breaker.NoopBreakerConfig()

app := utils.InitCLIParser("tctl", GlobalHelpString)
command.Initialize(app, cfg)
cmd.Initialize(app, cfg)

args = append([]string{"tokens"}, args...)
selectedCmd, err := app.Parse(args)
require.NoError(t, err)

var ccf GlobalCLIFlags
ccf.ConfigString = mustGetBase64EncFileConfig(t, fc)
ccf.Insecure = options.Insecure
ctx := context.Background()
client := getAuthClient(ctx, t, fc, opts...)
_, err = cmd.TryRun(ctx, selectedCmd, client)
return err
}

clientConfig, err := applyConfig(&ccf, cfg)
require.NoError(t, err)
func runResourceCommand(t *testing.T, fc *config.FileConfig, args []string, opts ...optionsFunc) (*bytes.Buffer, error) {
var stdoutBuff bytes.Buffer
command := &ResourceCommand{
stdout: &stdoutBuff,
}
return &stdoutBuff, runCommand(t, fc, command, args, opts...)
}

if options.CertPool != nil {
clientConfig.TLS.RootCAs = options.CertPool
func runTokensCommand(t *testing.T, fc *config.FileConfig, args []string, opts ...optionsFunc) (*bytes.Buffer, error) {
var stdoutBuff bytes.Buffer
command := &TokensCommand{
stdout: &stdoutBuff,
}

ctx := context.Background()
client, err := authclient.Connect(ctx, clientConfig)
require.NoError(t, err)
args = append([]string{"tokens"}, args...)
return &stdoutBuff, runCommand(t, fc, command, args, opts...)
}

_, err = command.TryRun(ctx, selectedCmd, client)
if err != nil {
return nil, err
}
return &stdoutBuff, nil
func runUserCommand(t *testing.T, fc *config.FileConfig, args []string, opts ...optionsFunc) error {
command := &UserCommand{}
args = append([]string{"users"}, args...)
return runCommand(t, fc, command, args, opts...)
}

func mustDecodeJSON(t *testing.T, r io.Reader, i interface{}) {
Expand Down
95 changes: 66 additions & 29 deletions tool/tctl/common/user_command.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,10 @@ type UserCommand struct {
allowedDatabaseUsers []string
allowedDatabaseNames []string
allowedAWSRoleARNs []string
createRoles []string
allowedRoles []string

ttl time.Duration

// updateRoles contains new roles for update users command
updateRoles string
// updateLogins contains new logins for update users command
updateLogins string

// format is the output format, e.g. text or json
format string

Expand Down Expand Up @@ -84,7 +79,7 @@ func (u *UserCommand) Initialize(app *kingpin.Application, config *service.Confi
u.userAdd.Flag("db-names", "List of allowed database names for the new user").StringsVar(&u.allowedDatabaseNames)
u.userAdd.Flag("aws-role-arns", "List of allowed AWS role ARNs for the new user").StringsVar(&u.allowedAWSRoleARNs)

u.userAdd.Flag("roles", "List of roles for the new user to assume").Required().StringsVar(&u.createRoles)
u.userAdd.Flag("roles", "List of roles for the new user to assume").Required().StringsVar(&u.allowedRoles)

u.userAdd.Flag("ttl", fmt.Sprintf("Set expiration time for token, default is %v, maximum is %v",
defaults.SignupTokenTTL, defaults.MaxSignupTokenTTL)).
Expand All @@ -95,9 +90,21 @@ func (u *UserCommand) Initialize(app *kingpin.Application, config *service.Confi
u.userUpdate = users.Command("update", "Update user account")
u.userUpdate.Arg("account", "Teleport user account name").Required().StringVar(&u.login)
u.userUpdate.Flag("set-roles", "List of roles for the user to assume, replaces current roles").
Default("").StringVar(&u.updateRoles)
u.userUpdate.Flag("set-logins", "List of SSH logins for the user, replaces current logins").
Default("").StringVar(&u.updateLogins)
StringsVar(&u.allowedRoles)
u.userUpdate.Flag("set-logins", "List of allowed SSH logins for the user, replaces current logins").
StringsVar(&u.allowedLogins)
u.userUpdate.Flag("set-windows-logins", "List of allowed Windows logins for the user, replaces current Windows logins").
StringsVar(&u.allowedWindowsLogins)
u.userUpdate.Flag("set-kubernetes-users", "List of allowed Kubernetes users for the user, replaces current Kubernetes users").
StringsVar(&u.allowedKubeUsers)
u.userUpdate.Flag("set-kubernetes-groups", "List of allowed Kubernetes groups for the user, replaces current Kubernetes groups").
StringsVar(&u.allowedKubeGroups)
u.userUpdate.Flag("set-db-users", "List of allowed database users for the user, replaces current database users").
StringsVar(&u.allowedDatabaseUsers)
u.userUpdate.Flag("set-db-names", "List of allowed database names for the user, replaces current database names").
StringsVar(&u.allowedDatabaseNames)
u.userUpdate.Flag("set-aws-role-arns", "List of allowed AWS role ARNs for the user, replaces current AWS role ARNs").
StringsVar(&u.allowedAWSRoleARNs)

u.userList = users.Command("ls", "Lists all user accounts.")
u.userList.Flag("format", "Output format, 'text' or 'json'").Hidden().Default(teleport.Text).StringVar(&u.format)
Expand Down Expand Up @@ -199,13 +206,12 @@ func (u *UserCommand) printResetPasswordToken(token types.UserToken, format stri
// Add implements `tctl users add` for the enterprise edition. Unlike the OSS
// version, this one requires --roles flag to be set
func (u *UserCommand) Add(ctx context.Context, client auth.ClientI) error {
u.createRoles = flattenSlice(u.createRoles)
u.allowedRoles = flattenSlice(u.allowedRoles)
u.allowedLogins = flattenSlice(u.allowedLogins)
u.allowedWindowsLogins = flattenSlice(u.allowedWindowsLogins)

// Validate roles
// DELETE IN 12.0.0
for _, roleName := range u.createRoles {
// Validate roles (server does not do this yet).
for _, roleName := range u.allowedRoles {
if _, err := client.GetRole(ctx, roleName); err != nil {
return trace.Wrap(err)
}
Expand All @@ -227,7 +233,7 @@ func (u *UserCommand) Add(ctx context.Context, client auth.ClientI) error {
}

user.SetTraits(traits)
user.SetRoles(u.createRoles)
user.SetRoles(u.allowedRoles)

if err := client.CreateUser(ctx, user); err != nil {
return trace.Wrap(err)
Expand Down Expand Up @@ -283,38 +289,69 @@ func printTokenAsText(token types.UserToken, messageFormat string) error {

// Update updates existing user
func (u *UserCommand) Update(ctx context.Context, client auth.ClientI) error {
if u.updateRoles == "" && u.updateLogins == "" {
return trace.BadParameter("Nothing to update. Please provide --set-roles or --set-logins flag.")
}
user, err := client.GetUser(u.login, false)
if err != nil {
return trace.Wrap(err)
}

var updateMessages []string
if u.updateRoles != "" {
roles := flattenSlice([]string{u.updateRoles})
updateMessages := make(map[string][]string)
if len(u.allowedRoles) > 0 {
roles := flattenSlice(u.allowedRoles)
for _, role := range roles {
if _, err := client.GetRole(ctx, role); err != nil {
return trace.Wrap(err)
}
}
user.SetRoles(roles)
updateMessages = append(updateMessages, "with roles "+strings.Join(user.GetRoles(), ","))
updateMessages["roles"] = roles
}
if len(u.allowedLogins) > 0 {
logins := flattenSlice(u.allowedLogins)
user.SetLogins(logins)
updateMessages["logins"] = logins
}
if len(u.allowedWindowsLogins) > 0 {
windowsLogins := flattenSlice(u.allowedWindowsLogins)
user.SetWindowsLogins(windowsLogins)
updateMessages["Windows logins"] = windowsLogins
}
if len(u.allowedKubeUsers) > 0 {
kubeUsers := flattenSlice(u.allowedKubeUsers)
user.SetKubeUsers(kubeUsers)
updateMessages["Kubernetes users"] = kubeUsers
}
if len(u.allowedKubeGroups) > 0 {
kubeGroups := flattenSlice(u.allowedKubeGroups)
user.SetKubeGroups(kubeGroups)
updateMessages["Kubernetes groups"] = kubeGroups
}
if len(u.allowedDatabaseUsers) > 0 {
dbUsers := flattenSlice(u.allowedDatabaseUsers)
user.SetDatabaseUsers(dbUsers)
updateMessages["database users"] = dbUsers
}
if len(u.allowedDatabaseNames) > 0 {
dbNames := flattenSlice(u.allowedDatabaseNames)
user.SetDatabaseNames(dbNames)
updateMessages["database names"] = dbNames
}
if len(u.allowedAWSRoleARNs) > 0 {
awsRoleARNs := flattenSlice(u.allowedAWSRoleARNs)
user.SetAWSRoleARNs(awsRoleARNs)
updateMessages["AWS role ARNs"] = awsRoleARNs
}

if u.updateLogins != "" {
logins := flattenSlice([]string{u.updateLogins})
traits := user.GetTraits()
traits[constants.TraitLogins] = logins
user.SetTraits(traits)
updateMessages = append(updateMessages, "with logins "+strings.Join(logins, ","))
if len(updateMessages) == 0 {
return trace.BadParameter("Nothing to update. Please provide at least one --set flag.")
}

if err := client.UpsertUser(user); err != nil {
return trace.Wrap(err)
}
fmt.Printf("%v has been updated %v\n", user.GetName(), strings.Join(updateMessages, " and "))
fmt.Printf("User %v has been updated:\n", user.GetName())
for field, values := range updateMessages {
fmt.Printf("\tNew %v: %v\n", field, strings.Join(values, ","))
}
return nil
}

Expand Down
Loading

0 comments on commit 4653503

Please sign in to comment.