From 465350336719d6383df76e155b6c529ca31c321a Mon Sep 17 00:00:00 2001 From: Andrew Burke <31974658+atburke@users.noreply.github.com> Date: Mon, 1 Aug 2022 16:30:43 -0700 Subject: [PATCH] `tctl` - Add --set flags for every trait (#14552) This change adds a --set flag to tctl users update for every trait where it wasn't implemented already (Windows logins, Kube users/groups, DB users/names, and AWS role ARNs). --- tool/tctl/common/helpers_test.go | 78 +++++++--------- tool/tctl/common/user_command.go | 95 ++++++++++++++------ tool/tctl/common/user_command_test.go | 122 ++++++++++++++++++++++++++ 3 files changed, 221 insertions(+), 74 deletions(-) diff --git a/tool/tctl/common/helpers_test.go b/tool/tctl/common/helpers_test.go index 620eae1ca4f5f..8314b4b7e7944 100644 --- a/tool/tctl/common/helpers_test.go +++ b/tool/tctl/common/helpers_test.go @@ -27,6 +27,7 @@ import ( "testing" "time" + "github.com/gravitational/kingpin" "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" "gopkg.in/yaml.v2" @@ -59,23 +60,12 @@ func withInsecure(insecure bool) optionsFunc { } } -func runResourceCommand(t *testing.T, fc *config.FileConfig, args []string, opts ...optionsFunc) (*bytes.Buffer, error) { +func getAuthClient(ctx context.Context, t *testing.T, fc *config.FileConfig, opts ...optionsFunc) auth.ClientI { var options options for _, v := range opts { v(&options) } - var stdoutBuff bytes.Buffer - command := &ResourceCommand{ - stdout: &stdoutBuff, - } cfg := service.MakeDefaultConfig() - cfg.CircuitBreakerConfig = breaker.NoopBreakerConfig() - - app := utils.InitCLIParser("tctl", GlobalHelpString) - command.Initialize(app, cfg) - - selectedCmd, err := app.Parse(args) - require.NoError(t, err) var ccf GlobalCLIFlags ccf.ConfigString = mustGetBase64EncFileConfig(t, fc) @@ -88,56 +78,54 @@ func runResourceCommand(t *testing.T, fc *config.FileConfig, args []string, opts clientConfig.TLS.RootCAs = options.CertPool } - ctx := context.Background() client, err := authclient.Connect(ctx, clientConfig) require.NoError(t, err) - - _, err = command.TryRun(ctx, selectedCmd, client) - if err != nil { - return nil, err - } - return &stdoutBuff, nil + return client } -func runTokensCommand(t *testing.T, fc *config.FileConfig, args []string, opts ...optionsFunc) (*bytes.Buffer, error) { - var options options - for _, v := range opts { - v(&options) - } +type cliCommand interface { + Initialize(app *kingpin.Application, cfg *service.Config) + TryRun(ctx context.Context, cmd string, client auth.ClientI) (bool, error) +} - var stdoutBuff bytes.Buffer - command := &TokensCommand{ - stdout: &stdoutBuff, - } +func runCommand(t *testing.T, fc *config.FileConfig, cmd cliCommand, args []string, opts ...optionsFunc) error { cfg := service.MakeDefaultConfig() + cfg.CircuitBreakerConfig = breaker.NoopBreakerConfig() app := utils.InitCLIParser("tctl", GlobalHelpString) - command.Initialize(app, cfg) + cmd.Initialize(app, cfg) - args = append([]string{"tokens"}, args...) selectedCmd, err := app.Parse(args) require.NoError(t, err) - var ccf GlobalCLIFlags - ccf.ConfigString = mustGetBase64EncFileConfig(t, fc) - ccf.Insecure = options.Insecure + ctx := context.Background() + client := getAuthClient(ctx, t, fc, opts...) + _, err = cmd.TryRun(ctx, selectedCmd, client) + return err +} - clientConfig, err := applyConfig(&ccf, cfg) - require.NoError(t, err) +func runResourceCommand(t *testing.T, fc *config.FileConfig, args []string, opts ...optionsFunc) (*bytes.Buffer, error) { + var stdoutBuff bytes.Buffer + command := &ResourceCommand{ + stdout: &stdoutBuff, + } + return &stdoutBuff, runCommand(t, fc, command, args, opts...) +} - if options.CertPool != nil { - clientConfig.TLS.RootCAs = options.CertPool +func runTokensCommand(t *testing.T, fc *config.FileConfig, args []string, opts ...optionsFunc) (*bytes.Buffer, error) { + var stdoutBuff bytes.Buffer + command := &TokensCommand{ + stdout: &stdoutBuff, } - ctx := context.Background() - client, err := authclient.Connect(ctx, clientConfig) - require.NoError(t, err) + args = append([]string{"tokens"}, args...) + return &stdoutBuff, runCommand(t, fc, command, args, opts...) +} - _, err = command.TryRun(ctx, selectedCmd, client) - if err != nil { - return nil, err - } - return &stdoutBuff, nil +func runUserCommand(t *testing.T, fc *config.FileConfig, args []string, opts ...optionsFunc) error { + command := &UserCommand{} + args = append([]string{"users"}, args...) + return runCommand(t, fc, command, args, opts...) } func mustDecodeJSON(t *testing.T, r io.Reader, i interface{}) { diff --git a/tool/tctl/common/user_command.go b/tool/tctl/common/user_command.go index 25145e830205b..7435d36ab92f8 100644 --- a/tool/tctl/common/user_command.go +++ b/tool/tctl/common/user_command.go @@ -47,15 +47,10 @@ type UserCommand struct { allowedDatabaseUsers []string allowedDatabaseNames []string allowedAWSRoleARNs []string - createRoles []string + allowedRoles []string ttl time.Duration - // updateRoles contains new roles for update users command - updateRoles string - // updateLogins contains new logins for update users command - updateLogins string - // format is the output format, e.g. text or json format string @@ -84,7 +79,7 @@ func (u *UserCommand) Initialize(app *kingpin.Application, config *service.Confi u.userAdd.Flag("db-names", "List of allowed database names for the new user").StringsVar(&u.allowedDatabaseNames) u.userAdd.Flag("aws-role-arns", "List of allowed AWS role ARNs for the new user").StringsVar(&u.allowedAWSRoleARNs) - u.userAdd.Flag("roles", "List of roles for the new user to assume").Required().StringsVar(&u.createRoles) + u.userAdd.Flag("roles", "List of roles for the new user to assume").Required().StringsVar(&u.allowedRoles) u.userAdd.Flag("ttl", fmt.Sprintf("Set expiration time for token, default is %v, maximum is %v", defaults.SignupTokenTTL, defaults.MaxSignupTokenTTL)). @@ -95,9 +90,21 @@ func (u *UserCommand) Initialize(app *kingpin.Application, config *service.Confi u.userUpdate = users.Command("update", "Update user account") u.userUpdate.Arg("account", "Teleport user account name").Required().StringVar(&u.login) u.userUpdate.Flag("set-roles", "List of roles for the user to assume, replaces current roles"). - Default("").StringVar(&u.updateRoles) - u.userUpdate.Flag("set-logins", "List of SSH logins for the user, replaces current logins"). - Default("").StringVar(&u.updateLogins) + StringsVar(&u.allowedRoles) + u.userUpdate.Flag("set-logins", "List of allowed SSH logins for the user, replaces current logins"). + StringsVar(&u.allowedLogins) + u.userUpdate.Flag("set-windows-logins", "List of allowed Windows logins for the user, replaces current Windows logins"). + StringsVar(&u.allowedWindowsLogins) + u.userUpdate.Flag("set-kubernetes-users", "List of allowed Kubernetes users for the user, replaces current Kubernetes users"). + StringsVar(&u.allowedKubeUsers) + u.userUpdate.Flag("set-kubernetes-groups", "List of allowed Kubernetes groups for the user, replaces current Kubernetes groups"). + StringsVar(&u.allowedKubeGroups) + u.userUpdate.Flag("set-db-users", "List of allowed database users for the user, replaces current database users"). + StringsVar(&u.allowedDatabaseUsers) + u.userUpdate.Flag("set-db-names", "List of allowed database names for the user, replaces current database names"). + StringsVar(&u.allowedDatabaseNames) + u.userUpdate.Flag("set-aws-role-arns", "List of allowed AWS role ARNs for the user, replaces current AWS role ARNs"). + StringsVar(&u.allowedAWSRoleARNs) u.userList = users.Command("ls", "Lists all user accounts.") u.userList.Flag("format", "Output format, 'text' or 'json'").Hidden().Default(teleport.Text).StringVar(&u.format) @@ -199,13 +206,12 @@ func (u *UserCommand) printResetPasswordToken(token types.UserToken, format stri // Add implements `tctl users add` for the enterprise edition. Unlike the OSS // version, this one requires --roles flag to be set func (u *UserCommand) Add(ctx context.Context, client auth.ClientI) error { - u.createRoles = flattenSlice(u.createRoles) + u.allowedRoles = flattenSlice(u.allowedRoles) u.allowedLogins = flattenSlice(u.allowedLogins) u.allowedWindowsLogins = flattenSlice(u.allowedWindowsLogins) - // Validate roles - // DELETE IN 12.0.0 - for _, roleName := range u.createRoles { + // Validate roles (server does not do this yet). + for _, roleName := range u.allowedRoles { if _, err := client.GetRole(ctx, roleName); err != nil { return trace.Wrap(err) } @@ -227,7 +233,7 @@ func (u *UserCommand) Add(ctx context.Context, client auth.ClientI) error { } user.SetTraits(traits) - user.SetRoles(u.createRoles) + user.SetRoles(u.allowedRoles) if err := client.CreateUser(ctx, user); err != nil { return trace.Wrap(err) @@ -283,38 +289,69 @@ func printTokenAsText(token types.UserToken, messageFormat string) error { // Update updates existing user func (u *UserCommand) Update(ctx context.Context, client auth.ClientI) error { - if u.updateRoles == "" && u.updateLogins == "" { - return trace.BadParameter("Nothing to update. Please provide --set-roles or --set-logins flag.") - } user, err := client.GetUser(u.login, false) if err != nil { return trace.Wrap(err) } - var updateMessages []string - if u.updateRoles != "" { - roles := flattenSlice([]string{u.updateRoles}) + updateMessages := make(map[string][]string) + if len(u.allowedRoles) > 0 { + roles := flattenSlice(u.allowedRoles) for _, role := range roles { if _, err := client.GetRole(ctx, role); err != nil { return trace.Wrap(err) } } user.SetRoles(roles) - updateMessages = append(updateMessages, "with roles "+strings.Join(user.GetRoles(), ",")) + updateMessages["roles"] = roles + } + if len(u.allowedLogins) > 0 { + logins := flattenSlice(u.allowedLogins) + user.SetLogins(logins) + updateMessages["logins"] = logins + } + if len(u.allowedWindowsLogins) > 0 { + windowsLogins := flattenSlice(u.allowedWindowsLogins) + user.SetWindowsLogins(windowsLogins) + updateMessages["Windows logins"] = windowsLogins + } + if len(u.allowedKubeUsers) > 0 { + kubeUsers := flattenSlice(u.allowedKubeUsers) + user.SetKubeUsers(kubeUsers) + updateMessages["Kubernetes users"] = kubeUsers + } + if len(u.allowedKubeGroups) > 0 { + kubeGroups := flattenSlice(u.allowedKubeGroups) + user.SetKubeGroups(kubeGroups) + updateMessages["Kubernetes groups"] = kubeGroups + } + if len(u.allowedDatabaseUsers) > 0 { + dbUsers := flattenSlice(u.allowedDatabaseUsers) + user.SetDatabaseUsers(dbUsers) + updateMessages["database users"] = dbUsers + } + if len(u.allowedDatabaseNames) > 0 { + dbNames := flattenSlice(u.allowedDatabaseNames) + user.SetDatabaseNames(dbNames) + updateMessages["database names"] = dbNames + } + if len(u.allowedAWSRoleARNs) > 0 { + awsRoleARNs := flattenSlice(u.allowedAWSRoleARNs) + user.SetAWSRoleARNs(awsRoleARNs) + updateMessages["AWS role ARNs"] = awsRoleARNs } - if u.updateLogins != "" { - logins := flattenSlice([]string{u.updateLogins}) - traits := user.GetTraits() - traits[constants.TraitLogins] = logins - user.SetTraits(traits) - updateMessages = append(updateMessages, "with logins "+strings.Join(logins, ",")) + if len(updateMessages) == 0 { + return trace.BadParameter("Nothing to update. Please provide at least one --set flag.") } if err := client.UpsertUser(user); err != nil { return trace.Wrap(err) } - fmt.Printf("%v has been updated %v\n", user.GetName(), strings.Join(updateMessages, " and ")) + fmt.Printf("User %v has been updated:\n", user.GetName()) + for field, values := range updateMessages { + fmt.Printf("\tNew %v: %v\n", field, strings.Join(values, ",")) + } return nil } diff --git a/tool/tctl/common/user_command_test.go b/tool/tctl/common/user_command_test.go index 8d40b99854381..aacee3ee9612d 100644 --- a/tool/tctl/common/user_command_test.go +++ b/tool/tctl/common/user_command_test.go @@ -15,9 +15,14 @@ package common import ( + "context" "testing" "time" + "github.com/gravitational/teleport/api/constants" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/config" + "github.com/gravitational/trace" "github.com/stretchr/testify/require" ) @@ -56,3 +61,120 @@ func TestTrimDurationSuffix(t *testing.T) { }) } } + +func TestUserUpdate(t *testing.T) { + fileConfig := &config.FileConfig{ + Global: config.Global{ + DataDir: t.TempDir(), + }, + Auth: config.Auth{ + Service: config.Service{ + EnabledFlag: "true", + ListenAddress: mustGetFreeLocalListenerAddr(t), + }, + }, + } + makeAndRunTestAuthServer(t, withFileConfig(fileConfig)) + ctx := context.Background() + client := getAuthClient(ctx, t, fileConfig) + + baseUser, err := types.NewUser("test-user") + require.NoError(t, err) + + tests := []struct { + name string + args []string + wantRoles []string + wantTraits map[string][]string + errorChecker func(error) bool + }{ + { + name: "no args", + errorChecker: trace.IsBadParameter, + }, + { + name: "new roles", + args: []string{"--set-roles", "editor,access"}, + wantRoles: []string{"editor", "access"}, + }, + { + name: "nonexistant roles", + args: []string{"--set-roles", "editor,access,fake"}, + errorChecker: trace.IsNotFound, + }, + { + name: "new logins", + args: []string{"--set-logins", "l1,l2,l3"}, + wantTraits: map[string][]string{ + constants.TraitLogins: {"l1", "l2", "l3"}, + }, + }, + { + name: "new windows logins", + args: []string{"--set-windows-logins", "w1,w2,w3"}, + wantTraits: map[string][]string{ + constants.TraitWindowsLogins: {"w1", "w2", "w3"}, + }, + }, + { + name: "new kube users", + args: []string{"--set-kubernetes-users", "k1,k2,k3"}, + wantTraits: map[string][]string{ + constants.TraitKubeUsers: {"k1", "k2", "k3"}, + }, + }, + { + name: "new kube groups", + args: []string{"--set-kubernetes-groups", "k4,k5,k6"}, + wantTraits: map[string][]string{ + constants.TraitKubeGroups: {"k4", "k5", "k6"}, + }, + }, + { + name: "new db users", + args: []string{"--set-db-users", "d1,d2,d3"}, + wantTraits: map[string][]string{ + constants.TraitDBUsers: {"d1", "d2", "d3"}, + }, + }, + { + name: "new db names", + args: []string{"--set-db-names", "d4,d5,d6"}, + wantTraits: map[string][]string{ + constants.TraitDBNames: {"d4", "d5", "d6"}, + }, + }, + { + name: "new AWS role ARNs", + args: []string{"--set-aws-role-arns", "a1,a2,a3"}, + wantTraits: map[string][]string{ + constants.TraitAWSRoleARNs: {"a1", "a2", "a3"}, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + require.NoError(t, client.UpsertUser(baseUser)) + args := append([]string{"update"}, tc.args...) + args = append(args, "test-user") + err := runUserCommand(t, fileConfig, args) + if tc.errorChecker != nil { + require.True(t, tc.errorChecker(err), err) + return + } + + require.NoError(t, err) + updatedUser, err := client.GetUser("test-user", false) + require.NoError(t, err) + + if len(tc.wantRoles) > 0 { + require.Equal(t, tc.wantRoles, updatedUser.GetRoles()) + } + + for trait, values := range tc.wantTraits { + require.Equal(t, values, updatedUser.GetTraits()[trait]) + } + }) + } +}