Skip to content

Commit

Permalink
[v8] Backport #10735: Fix MOTD not showing up on tsh login with certa…
Browse files Browse the repository at this point in the history
…in arguments (#11371)

* Fix MOTD not showing up on tsh login with certain arguments

- changes to configuration.go: fixes tsh login in first test case
  `tsh login --insecure --proxy=127.0.0.1:3080 --user=test`
- changes to apiserver.go fixes `--auth` not showing motd

* Add tests for motd fixes

Part of this includes renaming export_test.go to export.go so I could
test the MOTD outside of lib/client/export.go
  • Loading branch information
Alex McGrath authored Mar 28, 2022
1 parent c733c69 commit 2abd7ea
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 21 deletions.
29 changes: 19 additions & 10 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -2335,14 +2335,8 @@ func (tc *TeleportClient) LogoutAll() error {
return nil
}

// Login logs the user into a Teleport cluster by talking to a Teleport proxy.
//
// The returned Key should typically be passed to ActivateKey in order to
// update local agent state.
//
func (tc *TeleportClient) Login(ctx context.Context) (*Key, error) {
// Ping the endpoint to see if it's up and find the type of authentication
// supported.
// PingAndShowMOTD pings the Teleport Proxy and displays the Message Of The Day if it's available.
func (tc *TeleportClient) PingAndShowMOTD(ctx context.Context) (*webclient.PingResponse, error) {
pr, err := tc.Ping(ctx)
if err != nil {
return nil, trace.Wrap(err)
Expand All @@ -2354,6 +2348,21 @@ func (tc *TeleportClient) Login(ctx context.Context) (*Key, error) {
return nil, trace.Wrap(err)
}
}
return pr, nil
}

// Login logs the user into a Teleport cluster by talking to a Teleport proxy.
//
// The returned Key should typically be passed to ActivateKey in order to
// update local agent state.
//
func (tc *TeleportClient) Login(ctx context.Context) (*Key, error) {
// Ping the endpoint to see if it's up and find the type of authentication
// supported, also show the message of the day if available.
pr, err := tc.PingAndShowMOTD(ctx)
if err != nil {
return nil, trace.Wrap(err)
}

// generate a new keypair. the public key will be signed via proxy if client's
// password+OTP are valid
Expand Down Expand Up @@ -2539,13 +2548,13 @@ func (tc *TeleportClient) ShowMOTD(ctx context.Context) error {
}

if motd.Text != "" {
fmt.Printf("%s\nPress [ENTER] to continue.\n", motd.Text)
fmt.Fprintf(tc.Stderr, "%s\nPress [ENTER] to continue.\n", motd.Text)
// We're re-using the password reader for user acknowledgment for
// aesthetic purposes, because we want to hide any garbage the
// use might enter at the prompt. Whatever the user enters will
// be simply discarded, and the user can still CTRL+C out if they
// disagree.
_, err := passwordFromConsole()
_, err := passwordFromConsoleFn()
if err != nil {
return trace.Wrap(err)
}
Expand Down
File renamed without changes.
3 changes: 3 additions & 0 deletions lib/config/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,9 @@ func applyAuthConfig(fc *FileConfig, cfg *service.Config) error {
if err != nil {
return trace.Wrap(err)
}
}

if fc.Auth.MessageOfTheDay != "" {
cfg.Auth.Preference.SetMessageOfTheDay(fc.Auth.MessageOfTheDay)
}

Expand Down
8 changes: 6 additions & 2 deletions lib/web/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -791,33 +791,37 @@ func (h *Handler) pingWithConnector(w http.ResponseWriter, r *http.Request, p ht
ServerVersion: teleport.Version,
}

hasMessageOfTheDay := cap.GetMessageOfTheDay() != ""
if connectorName == constants.Local {
as, err := localSettings(cap)
response.Auth, err = localSettings(cap)
if err != nil {
return nil, trace.Wrap(err)
}
response.Auth = as
response.Auth.HasMessageOfTheDay = hasMessageOfTheDay
return response, nil
}

// first look for a oidc connector with that name
oidcConnector, err := authClient.GetOIDCConnector(r.Context(), connectorName, false)
if err == nil {
response.Auth = oidcSettings(oidcConnector, cap)
response.Auth.HasMessageOfTheDay = hasMessageOfTheDay
return response, nil
}

// if no oidc connector was found, look for a saml connector
samlConnector, err := authClient.GetSAMLConnector(r.Context(), connectorName, false)
if err == nil {
response.Auth = samlSettings(samlConnector, cap)
response.Auth.HasMessageOfTheDay = hasMessageOfTheDay
return response, nil
}

// look for github connector
githubConnector, err := authClient.GetGithubConnector(r.Context(), connectorName, false)
if err == nil {
response.Auth = githubSettings(githubConnector, cap)
response.Auth.HasMessageOfTheDay = hasMessageOfTheDay
return response, nil
}

Expand Down
31 changes: 31 additions & 0 deletions tool/tsh/tsh.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,8 @@ type CLIConf struct {

// overrideStdout allows to switch standard output source for resource command. Used in tests.
overrideStdout io.Writer
// overrideStderr allows to switch standard error source for resource command. Used in tests.
overrideStderr io.Writer

// mockSSOLogin used in tests to override sso login handler in teleport client.
mockSSOLogin client.SSOLoginFunc
Expand Down Expand Up @@ -284,6 +286,14 @@ func (c *CLIConf) Stdout() io.Writer {
return os.Stdout
}

// Stderr returns the stderr writer.
func (c *CLIConf) Stderr() io.Writer {
if c.overrideStderr != nil {
return c.overrideStderr
}
return os.Stderr
}

func main() {
cmdLineOrig := os.Args[1:]
var cmdLine []string
Expand Down Expand Up @@ -844,23 +854,37 @@ func onLogin(cf *CLIConf) error {
// in case if nothing is specified, re-fetch kube clusters and print
// current status
case cf.Proxy == "" && cf.SiteName == "" && cf.DesiredRoles == "" && cf.RequestID == "" && cf.IdentityFileOut == "":
_, err := tc.PingAndShowMOTD(cf.Context)
if err != nil {
return trace.Wrap(err)
}
if err := updateKubeConfig(cf, tc, ""); err != nil {
return trace.Wrap(err)
}
printProfiles(cf.Debug, profile, profiles)

return nil
// in case if parameters match, re-fetch kube clusters and print
// current status
case host(cf.Proxy) == host(profile.ProxyURL.Host) && cf.SiteName == profile.Cluster && cf.DesiredRoles == "" && cf.RequestID == "":
_, err := tc.PingAndShowMOTD(cf.Context)
if err != nil {
return trace.Wrap(err)
}
if err := updateKubeConfig(cf, tc, ""); err != nil {
return trace.Wrap(err)
}
printProfiles(cf.Debug, profile, profiles)

return nil
// proxy is unspecified or the same as the currently provided proxy,
// but cluster is specified, treat this as selecting a new cluster
// for the same proxy
case (cf.Proxy == "" || host(cf.Proxy) == host(profile.ProxyURL.Host)) && cf.SiteName != "":
_, err := tc.PingAndShowMOTD(cf.Context)
if err != nil {
return trace.Wrap(err)
}
// trigger reissue, preserving any active requests.
err = tc.ReissueUserCerts(cf.Context, client.CertCacheKeep, client.ReissueParams{
AccessRequests: profile.ActiveRequests.AccessRequests,
Expand All @@ -875,11 +899,16 @@ func onLogin(cf *CLIConf) error {
if err := updateKubeConfig(cf, tc, ""); err != nil {
return trace.Wrap(err)
}

return trace.Wrap(onStatus(cf))
// proxy is unspecified or the same as the currently provided proxy,
// but desired roles or request ID is specified, treat this as a
// privilege escalation request for the same login session.
case (cf.Proxy == "" || host(cf.Proxy) == host(profile.ProxyURL.Host)) && (cf.DesiredRoles != "" || cf.RequestID != "") && cf.IdentityFileOut == "":
_, err := tc.PingAndShowMOTD(cf.Context)
if err != nil {
return trace.Wrap(err)
}
if err := executeAccessRequest(cf, tc); err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -2067,6 +2096,8 @@ func makeClient(cf *CLIConf, useProfileLogin bool) (*client.TeleportClient, erro
}
}

tc.Config.Stderr = cf.Stderr()
tc.Config.Stdout = cf.Stdout()
return tc, nil
}

Expand Down
73 changes: 64 additions & 9 deletions tool/tsh/tsh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ limitations under the License.
package main

import (
"bufio"
"bytes"
"context"
"fmt"
"io/ioutil"
Expand Down Expand Up @@ -175,7 +177,11 @@ func TestOIDCLogin(t *testing.T) {

connector := mockConnector(t)

authProcess, proxyProcess := makeTestServers(t, withBootstrap(populist, dictator, connector, alice))
motd := "MESSAGE_OF_THE_DAY_OIDC"
authProcess, proxyProcess := makeTestServers(t,
withBootstrap(populist, dictator, connector, alice),
withMOTD(t, motd),
)

authServer := authProcess.GetAuthServer()
require.NotNil(t, authServer)
Expand Down Expand Up @@ -213,6 +219,8 @@ func TestOIDCLogin(t *testing.T) {
}
}()

buf := bytes.NewBuffer([]byte{})
sc := bufio.NewScanner(buf)
err = Run([]string{
"login",
"--insecure",
Expand All @@ -223,6 +231,7 @@ func TestOIDCLogin(t *testing.T) {
}, setHomePath(tmpHomePath), cliOption(func(cf *CLIConf) error {
cf.mockSSOLogin = mockSSOLogin(t, authServer, alice)
cf.SiteName = "localhost"
cf.overrideStderr = buf
return nil
}))

Expand All @@ -231,11 +240,22 @@ func TestOIDCLogin(t *testing.T) {
// verify that auto-request happened
require.True(t, didAutoRequest.Load())

findMOTD(t, sc, motd)
// if we got this far, then tsh successfully registered name change from `alice` to
// `[email protected]`, since the correct name needed to be used for the access
// request to be generated.
}

func findMOTD(t *testing.T, sc *bufio.Scanner, motd string) {
t.Helper()
for sc.Scan() {
if strings.Contains(sc.Text(), motd) {
return
}
}
require.Fail(t, "Failed to find %q MOTD in the logs", motd)
}

// TestLoginIdentityOut makes sure that "tsh login --out <ident>" command
// writes identity credentials to the specified path.
func TestLoginIdentityOut(t *testing.T) {
Expand Down Expand Up @@ -283,14 +303,20 @@ func TestRelogin(t *testing.T) {
require.NoError(t, err)
alice.SetRoles([]string{"access"})

authProcess, proxyProcess := makeTestServers(t, withBootstrap(connector, alice))
motd := "RELOGIN MOTD PRESENT"
authProcess, proxyProcess := makeTestServers(t,
withBootstrap(connector, alice),
withMOTD(t, motd),
)

authServer := authProcess.GetAuthServer()
require.NotNil(t, authServer)

proxyAddr, err := proxyProcess.ProxyWebAddr()
require.NoError(t, err)

buf := bytes.NewBuffer([]byte{})
sc := bufio.NewScanner(buf)
err = Run([]string{
"login",
"--insecure",
Expand All @@ -299,20 +325,32 @@ func TestRelogin(t *testing.T) {
"--proxy", proxyAddr.String(),
}, setHomePath(tmpHomePath), cliOption(func(cf *CLIConf) error {
cf.mockSSOLogin = mockSSOLogin(t, authServer, alice)
cf.overrideStderr = buf
return nil
}))
require.NoError(t, err)
findMOTD(t, sc, motd)

err = Run([]string{
"login",
"--insecure",
"--debug",
"--proxy", proxyAddr.String(),
"localhost",
}, setHomePath(tmpHomePath))
}, setHomePath(tmpHomePath),
cliOption(func(cf *CLIConf) error {
cf.mockSSOLogin = mockSSOLogin(t, authServer, alice)
cf.overrideStderr = buf
return nil
}))
require.NoError(t, err)
findMOTD(t, sc, motd)

err = Run([]string{"logout"}, setHomePath(tmpHomePath))
err = Run([]string{"logout"}, setHomePath(tmpHomePath),
cliOption(func(cf *CLIConf) error {
cf.overrideStderr = buf
return nil
}))
require.NoError(t, err)

err = Run([]string{
Expand All @@ -324,8 +362,10 @@ func TestRelogin(t *testing.T) {
"localhost",
}, setHomePath(tmpHomePath), cliOption(func(cf *CLIConf) error {
cf.mockSSOLogin = mockSSOLogin(t, authServer, alice)
cf.overrideStderr = buf
return nil
}))
findMOTD(t, sc, motd)
require.NoError(t, err)
}

Expand Down Expand Up @@ -1247,8 +1287,8 @@ func TestMakeTableWithTruncatedColumn(t *testing.T) {
}

type testServersOpts struct {
bootstrap []types.Resource
authConfigFunc func(cfg *service.AuthConfig)
bootstrap []types.Resource
authConfigFuncs []func(cfg *service.AuthConfig)
}

type testServerOptFunc func(o *testServersOpts)
Expand All @@ -1261,7 +1301,11 @@ func withBootstrap(bootstrap ...types.Resource) testServerOptFunc {

func withAuthConfig(fn func(cfg *service.AuthConfig)) testServerOptFunc {
return func(o *testServersOpts) {
o.authConfigFunc = fn
if o.authConfigFuncs == nil {
o.authConfigFuncs = []func(cfg *service.AuthConfig){}
}

o.authConfigFuncs = append(o.authConfigFuncs, fn)
}
}

Expand All @@ -1276,6 +1320,17 @@ func withClusterName(t *testing.T, n string) testServerOptFunc {
})
}

func withMOTD(t *testing.T, motd string) testServerOptFunc {
oldpass := client.PasswordFromConsoleFn
*client.PasswordFromConsoleFn = func() (string, error) {
return "", nil
}
t.Cleanup(func() { *client.PasswordFromConsoleFn = *oldpass })
return withAuthConfig(func(cfg *service.AuthConfig) {
cfg.Preference.SetMessageOfTheDay(motd)
})
}

func makeTestServers(t *testing.T, opts ...testServerOptFunc) (auth *service.TeleportProcess, proxy *service.TeleportProcess) {
var options testServersOpts
for _, opt := range opts {
Expand Down Expand Up @@ -1308,8 +1363,8 @@ func makeTestServers(t *testing.T, opts ...testServerOptFunc) (auth *service.Tel
cfg.Proxy.Enabled = false
cfg.Log = utils.NewLoggerForTests()

if options.authConfigFunc != nil {
options.authConfigFunc(&cfg.Auth)
for _, fn := range options.authConfigFuncs {
fn(&cfg.Auth)
}

auth, err = service.NewTeleport(cfg)
Expand Down

0 comments on commit 2abd7ea

Please sign in to comment.