Skip to content

Commit

Permalink
Add tests for motd fixes
Browse files Browse the repository at this point in the history
Part of this includes renaming export_test.go to export.go so I could
test the MOTD outside of lib/client/export.go
  • Loading branch information
Alex McGrath committed Mar 24, 2022
1 parent 81bd617 commit 6d9500b
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 10 deletions.
8 changes: 5 additions & 3 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -2226,7 +2226,7 @@ func (tc *TeleportClient) LogoutAll() error {
return nil
}

// PingAndShowMOTD pings the Teleport Proxy and displays the MOTD if it's available.
// PingAndShowMOTD pings the Teleport Proxy and displays the Message Of The Day if it's available.
func (tc *TeleportClient) PingAndShowMOTD(ctx context.Context) (*webclient.PingResponse, error) {
pr, err := tc.Ping(ctx)
if err != nil {
Expand Down Expand Up @@ -2441,7 +2441,7 @@ func (tc *TeleportClient) ShowMOTD(ctx context.Context) error {
// use might enter at the prompt. Whatever the user enters will
// be simply discarded, and the user can still CTRL+C out if they
// disagree.
_, err := passwordFromConsole()
_, err := passwordFromConsoleFn()
if err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -2869,7 +2869,7 @@ func (tc *TeleportClient) AskOTP() (token string, err error) {
// AskPassword prompts the user to enter the password
func (tc *TeleportClient) AskPassword() (pwd string, err error) {
fmt.Printf("Enter password for Teleport user %v:\n", tc.Config.Username)
pwd, err = passwordFromConsole()
pwd, err = passwordFromConsoleFn()
if err != nil {
fmt.Fprintln(tc.Stderr, err)
return "", trace.Wrap(err)
Expand Down Expand Up @@ -2918,6 +2918,8 @@ func (tc *TeleportClient) getServerVersion(nodeClient *NodeClient) (string, erro
}
}

var passwordFromConsoleFn = passwordFromConsole

// passwordFromConsole reads from stdin without echoing typed characters to stdout
func passwordFromConsole() (string, error) {
// syscall.Stdin is not an int on windows. The linter will complain only on
Expand Down
18 changes: 18 additions & 0 deletions lib/client/export.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// Copyright 2021 Gravitational, Inc
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package client

// PasswordFromConsoleFn exports passwordFromConsoleFn for tests.
var PasswordFromConsoleFn = &passwordFromConsoleFn
12 changes: 12 additions & 0 deletions tool/tsh/tsh.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,9 @@ type CLIConf struct {
// unsetEnvironment unsets Teleport related environment variables.
unsetEnvironment bool

// overrideStderr allows to switch standard error source for resource command. Used in tests.
overrideStderr io.Writer

// mockSSOLogin used in tests to override sso login handler in teleport client.
mockSSOLogin client.SSOLoginFunc

Expand All @@ -249,6 +252,14 @@ type CLIConf struct {
ConfigProxyTarget string
}

// Stderr returns the stderr writer.
func (c *CLIConf) Stderr() io.Writer {
if c.overrideStderr != nil {
return c.overrideStderr
}
return os.Stderr
}

func main() {
cmdLineOrig := os.Args[1:]
var cmdLine []string
Expand Down Expand Up @@ -1945,6 +1956,7 @@ func makeClient(cf *CLIConf, useProfileLogin bool) (*client.TeleportClient, erro
}
}

tc.Config.Stderr = cf.Stderr()
return tc, nil
}

Expand Down
51 changes: 44 additions & 7 deletions tool/tsh/tsh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@ limitations under the License.
package main

import (
"bufio"
"bytes"
"context"
"fmt"
"io/ioutil"
"net"
"os"
"path/filepath"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -154,7 +157,8 @@ func TestOIDCLogin(t *testing.T) {

connector := mockConnector(t)

authProcess, proxyProcess := makeTestServers(t, populist, dictator, connector, alice)
motd := "MESSAGE_OF_THE_DAY_OIDC"
authProcess, proxyProcess := makeTestServersWithMotd(t, motd, populist, dictator, connector, alice)

authServer := authProcess.GetAuthServer()
require.NotNil(t, authServer)
Expand Down Expand Up @@ -192,6 +196,8 @@ func TestOIDCLogin(t *testing.T) {
}
}()

buf := bytes.NewBuffer([]byte{})
sc := bufio.NewScanner(buf)
err = Run([]string{
"login",
"--insecure",
Expand All @@ -202,6 +208,7 @@ func TestOIDCLogin(t *testing.T) {
}, cliOption(func(cf *CLIConf) error {
cf.mockSSOLogin = mockSSOLogin(t, authServer, alice)
cf.SiteName = "localhost"
cf.overrideStderr = buf
return nil
}))

Expand All @@ -210,11 +217,22 @@ func TestOIDCLogin(t *testing.T) {
// verify that auto-request happened
require.True(t, didAutoRequest.Load())

findMOTD(t, sc, motd)
// if we got this far, then tsh successfully registered name change from `alice` to
// `[email protected]`, since the correct name needed to be used for the access
// request to be generated.
}

func findMOTD(t *testing.T, sc *bufio.Scanner, motd string) {
t.Helper()
for sc.Scan() {
if strings.Contains(sc.Text(), motd) {
return
}
}
require.Fail(t, "Failed to find %q MOTD in the logs", motd)
}

// TestLoginIdentityOut makes sure that "tsh login --out <ident>" command
// writes identity credentials to the specified path.
func TestLoginIdentityOut(t *testing.T) {
Expand Down Expand Up @@ -268,14 +286,17 @@ func TestRelogin(t *testing.T) {
require.NoError(t, err)
alice.SetRoles([]string{"admin"})

authProcess, proxyProcess := makeTestServers(t, connector, alice)
motd := "RELOGIN MOTD PRESENT"
authProcess, proxyProcess := makeTestServersWithMotd(t, motd, connector, alice)

authServer := authProcess.GetAuthServer()
require.NotNil(t, authServer)

proxyAddr, err := proxyProcess.ProxyWebAddr()
require.NoError(t, err)

buf := bytes.NewBuffer([]byte{})
sc := bufio.NewScanner(buf)
err = Run([]string{
"login",
"--insecure",
Expand All @@ -284,20 +305,30 @@ func TestRelogin(t *testing.T) {
"--proxy", proxyAddr.String(),
}, cliOption(func(cf *CLIConf) error {
cf.mockSSOLogin = mockSSOLogin(t, authServer, alice)
cf.overrideStderr = buf
return nil
}))
require.NoError(t, err)
findMOTD(t, sc, motd)

err = Run([]string{
"login",
"--insecure",
"--debug",
"--proxy", proxyAddr.String(),
"localhost",
})
}, cliOption(func(cf *CLIConf) error {
cf.overrideStderr = buf
return nil
}))
require.NoError(t, err)
findMOTD(t, sc, motd)

err = Run([]string{"logout"})
err = Run([]string{"logout"},
cliOption(func(cf *CLIConf) error {
cf.overrideStderr = buf
return nil
}))
require.NoError(t, err)

err = Run([]string{
Expand All @@ -309,8 +340,10 @@ func TestRelogin(t *testing.T) {
"localhost",
}, cliOption(func(cf *CLIConf) error {
cf.mockSSOLogin = mockSSOLogin(t, authServer, alice)
cf.overrideStderr = buf
return nil
}))
findMOTD(t, sc, motd)
require.NoError(t, err)
}

Expand Down Expand Up @@ -480,7 +513,7 @@ func TestAccessRequestOnLeaf(t *testing.T) {
})
require.NoError(t, err)

leafAuth, _ := makeTestServersWithName(t, "leafcluster")
leafAuth, _ := makeTestServersWithName(t, "leafcluster", "")
tryCreateTrustedCluster(t, leafAuth.GetAuthServer(), trustedCluster)

err = Run([]string{
Expand Down Expand Up @@ -1071,10 +1104,14 @@ func TestKubeConfigUpdate(t *testing.T) {
}

func makeTestServers(t *testing.T, bootstrap ...types.Resource) (auth *service.TeleportProcess, proxy *service.TeleportProcess) {
return makeTestServersWithName(t, "", bootstrap...)
return makeTestServersWithName(t, "", "", bootstrap...)
}

func makeTestServersWithMotd(t *testing.T, motd string, bootstrap ...types.Resource) (auth *service.TeleportProcess, proxy *service.TeleportProcess) {
return makeTestServersWithName(t, "", motd, bootstrap...)
}

func makeTestServersWithName(t *testing.T, name string, bootstrap ...types.Resource) (auth *service.TeleportProcess, proxy *service.TeleportProcess) {
func makeTestServersWithName(t *testing.T, name string, motd string, bootstrap ...types.Resource) (auth *service.TeleportProcess, proxy *service.TeleportProcess) {
var err error
// Set up a test auth server.
//
Expand Down

0 comments on commit 6d9500b

Please sign in to comment.