Skip to content

Commit

Permalink
Add tests for motd fixes
Browse files Browse the repository at this point in the history
Part of this includes renaming export_test.go to export.go so I could
test the MOTD outside of lib/client/export.go
  • Loading branch information
Alex McGrath committed Mar 15, 2022
1 parent 5d56de2 commit 50e7db8
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 11 deletions.
4 changes: 2 additions & 2 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -2437,7 +2437,7 @@ func (tc *TeleportClient) LogoutAll() error {
return nil
}

// PingAndShowMOTD pings the Teleport Proxy and displays the MOTD if it's available.
// PingAndShowMOTD pings the Teleport Proxy and displays the Message Of The Day if it's available.
func (tc *TeleportClient) PingAndShowMOTD(ctx context.Context) (*webclient.PingResponse, error) {
pr, err := tc.Ping(ctx)
if err != nil {
Expand Down Expand Up @@ -2652,7 +2652,7 @@ func (tc *TeleportClient) ShowMOTD(ctx context.Context) error {
// use might enter at the prompt. Whatever the user enters will
// be simply discarded, and the user can still CTRL+C out if they
// disagree.
_, err := passwordFromConsole()
_, err := passwordFromConsoleFn()
if err != nil {
return trace.Wrap(err)
}
Expand Down
File renamed without changes.
13 changes: 13 additions & 0 deletions tool/tsh/tsh.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,8 @@ type CLIConf struct {

// overrideStdout allows to switch standard output source for resource command. Used in tests.
overrideStdout io.Writer
// overrideStderr allows to switch standard error source for resource command. Used in tests.
overrideStderr io.Writer

// mockSSOLogin used in tests to override sso login handler in teleport client.
mockSSOLogin client.SSOLoginFunc
Expand Down Expand Up @@ -302,6 +304,14 @@ func (c *CLIConf) Stdout() io.Writer {
return os.Stdout
}

// Stderr returns the stderr writer.
func (c *CLIConf) Stderr() io.Writer {
if c.overrideStderr != nil {
return c.overrideStderr
}
return os.Stderr
}

func main() {
cmdLineOrig := os.Args[1:]
var cmdLine []string
Expand Down Expand Up @@ -2131,6 +2141,9 @@ func makeClient(cf *CLIConf, useProfileLogin bool) (*client.TeleportClient, erro
}
}

tc.Config.Stderr = cf.Stderr()
tc.Config.Stdout = cf.Stdout()

tc.Config.Reason = cf.Reason
tc.Config.Invited = cf.Invited
tc.Config.DisplayParticipantRequirements = cf.displayParticipantRequirements
Expand Down
73 changes: 64 additions & 9 deletions tool/tsh/tsh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ limitations under the License.
package main

import (
"bufio"
"bytes"
"context"
"fmt"
"io/ioutil"
Expand Down Expand Up @@ -175,7 +177,11 @@ func TestOIDCLogin(t *testing.T) {

connector := mockConnector(t)

authProcess, proxyProcess := makeTestServers(t, withBootstrap(populist, dictator, connector, alice))
motd := "MESSAGE_OF_THE_DAY_OIDC"
authProcess, proxyProcess := makeTestServers(t,
withBootstrap(populist, dictator, connector, alice),
withMOTD(t, motd),
)

authServer := authProcess.GetAuthServer()
require.NotNil(t, authServer)
Expand Down Expand Up @@ -213,6 +219,8 @@ func TestOIDCLogin(t *testing.T) {
}
}()

buf := bytes.NewBuffer([]byte{})
sc := bufio.NewScanner(buf)
err = Run([]string{
"login",
"--insecure",
Expand All @@ -223,6 +231,7 @@ func TestOIDCLogin(t *testing.T) {
}, setHomePath(tmpHomePath), cliOption(func(cf *CLIConf) error {
cf.mockSSOLogin = mockSSOLogin(t, authServer, alice)
cf.SiteName = "localhost"
cf.overrideStderr = buf
return nil
}))

Expand All @@ -231,11 +240,22 @@ func TestOIDCLogin(t *testing.T) {
// verify that auto-request happened
require.True(t, didAutoRequest.Load())

findMOTD(t, sc, motd)
// if we got this far, then tsh successfully registered name change from `alice` to
// `[email protected]`, since the correct name needed to be used for the access
// request to be generated.
}

func findMOTD(t *testing.T, sc *bufio.Scanner, motd string) {
t.Helper()
for sc.Scan() {
if strings.Contains(sc.Text(), motd) {
return
}
}
require.Fail(t, "Failed to find %q MOTD in the logs", motd)
}

// TestLoginIdentityOut makes sure that "tsh login --out <ident>" command
// writes identity credentials to the specified path.
func TestLoginIdentityOut(t *testing.T) {
Expand Down Expand Up @@ -283,14 +303,20 @@ func TestRelogin(t *testing.T) {
require.NoError(t, err)
alice.SetRoles([]string{"access"})

authProcess, proxyProcess := makeTestServers(t, withBootstrap(connector, alice))
motd := "RELOGIN MOTD PRESENT"
authProcess, proxyProcess := makeTestServers(t,
withBootstrap(connector, alice),
withMOTD(t, motd),
)

authServer := authProcess.GetAuthServer()
require.NotNil(t, authServer)

proxyAddr, err := proxyProcess.ProxyWebAddr()
require.NoError(t, err)

buf := bytes.NewBuffer([]byte{})
sc := bufio.NewScanner(buf)
err = Run([]string{
"login",
"--insecure",
Expand All @@ -299,20 +325,32 @@ func TestRelogin(t *testing.T) {
"--proxy", proxyAddr.String(),
}, setHomePath(tmpHomePath), cliOption(func(cf *CLIConf) error {
cf.mockSSOLogin = mockSSOLogin(t, authServer, alice)
cf.overrideStderr = buf
return nil
}))
require.NoError(t, err)
findMOTD(t, sc, motd)

err = Run([]string{
"login",
"--insecure",
"--debug",
"--proxy", proxyAddr.String(),
"localhost",
}, setHomePath(tmpHomePath))
}, setHomePath(tmpHomePath),
cliOption(func(cf *CLIConf) error {
cf.mockSSOLogin = mockSSOLogin(t, authServer, alice)
cf.overrideStderr = buf
return nil
}))
require.NoError(t, err)
findMOTD(t, sc, motd)

err = Run([]string{"logout"}, setHomePath(tmpHomePath))
err = Run([]string{"logout"}, setHomePath(tmpHomePath),
cliOption(func(cf *CLIConf) error {
cf.overrideStderr = buf
return nil
}))
require.NoError(t, err)

err = Run([]string{
Expand All @@ -324,8 +362,10 @@ func TestRelogin(t *testing.T) {
"localhost",
}, setHomePath(tmpHomePath), cliOption(func(cf *CLIConf) error {
cf.mockSSOLogin = mockSSOLogin(t, authServer, alice)
cf.overrideStderr = buf
return nil
}))
findMOTD(t, sc, motd)
require.NoError(t, err)
}

Expand Down Expand Up @@ -1238,8 +1278,8 @@ func TestSetX11Config(t *testing.T) {
}

type testServersOpts struct {
bootstrap []types.Resource
authConfigFunc func(cfg *service.AuthConfig)
bootstrap []types.Resource
authConfigFuncs []func(cfg *service.AuthConfig)
}

type testServerOptFunc func(o *testServersOpts)
Expand All @@ -1252,7 +1292,11 @@ func withBootstrap(bootstrap ...types.Resource) testServerOptFunc {

func withAuthConfig(fn func(cfg *service.AuthConfig)) testServerOptFunc {
return func(o *testServersOpts) {
o.authConfigFunc = fn
if o.authConfigFuncs == nil {
o.authConfigFuncs = []func(cfg *service.AuthConfig){}
}

o.authConfigFuncs = append(o.authConfigFuncs, fn)
}
}

Expand All @@ -1267,6 +1311,17 @@ func withClusterName(t *testing.T, n string) testServerOptFunc {
})
}

func withMOTD(t *testing.T, motd string) testServerOptFunc {
oldpass := client.PasswordFromConsoleFn
*client.PasswordFromConsoleFn = func() (string, error) {
return "", nil
}
t.Cleanup(func() { *client.PasswordFromConsoleFn = *oldpass })
return withAuthConfig(func(cfg *service.AuthConfig) {
cfg.Preference.SetMessageOfTheDay(motd)
})
}

func makeTestServers(t *testing.T, opts ...testServerOptFunc) (auth *service.TeleportProcess, proxy *service.TeleportProcess) {
var options testServersOpts
for _, opt := range opts {
Expand Down Expand Up @@ -1299,8 +1354,8 @@ func makeTestServers(t *testing.T, opts ...testServerOptFunc) (auth *service.Tel
cfg.Proxy.Enabled = false
cfg.Log = utils.NewLoggerForTests()

if options.authConfigFunc != nil {
options.authConfigFunc(&cfg.Auth)
for _, fn := range options.authConfigFuncs {
fn(&cfg.Auth)
}

auth, err = service.NewTeleport(cfg)
Expand Down

0 comments on commit 50e7db8

Please sign in to comment.