Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix MOTD not showing up on tsh login with certain arguments #10735

Merged
merged 2 commits into from
Mar 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -2437,14 +2437,8 @@ func (tc *TeleportClient) LogoutAll() error {
return nil
}

// Login logs the user into a Teleport cluster by talking to a Teleport proxy.
//
// The returned Key should typically be passed to ActivateKey in order to
// update local agent state.
//
func (tc *TeleportClient) Login(ctx context.Context) (*Key, error) {
// Ping the endpoint to see if it's up and find the type of authentication
// supported.
// PingAndShowMOTD pings the Teleport Proxy and displays the Message Of The Day if it's available.
func (tc *TeleportClient) PingAndShowMOTD(ctx context.Context) (*webclient.PingResponse, error) {
pr, err := tc.Ping(ctx)
if err != nil {
return nil, trace.Wrap(err)
Expand All @@ -2456,6 +2450,21 @@ func (tc *TeleportClient) Login(ctx context.Context) (*Key, error) {
return nil, trace.Wrap(err)
}
}
return pr, nil
}

// Login logs the user into a Teleport cluster by talking to a Teleport proxy.
//
// The returned Key should typically be passed to ActivateKey in order to
// update local agent state.
//
func (tc *TeleportClient) Login(ctx context.Context) (*Key, error) {
// Ping the endpoint to see if it's up and find the type of authentication
// supported, also show the message of the day if available.
pr, err := tc.PingAndShowMOTD(ctx)
if err != nil {
return nil, trace.Wrap(err)
}

// generate a new keypair. the public key will be signed via proxy if client's
// password+OTP are valid
Expand Down Expand Up @@ -2643,7 +2652,7 @@ func (tc *TeleportClient) ShowMOTD(ctx context.Context) error {
// use might enter at the prompt. Whatever the user enters will
// be simply discarded, and the user can still CTRL+C out if they
// disagree.
_, err := passwordFromConsole()
_, err := passwordFromConsoleFn()
if err != nil {
return trace.Wrap(err)
}
Expand Down
File renamed without changes.
3 changes: 3 additions & 0 deletions lib/config/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,9 @@ func applyAuthConfig(fc *FileConfig, cfg *service.Config) error {
if err != nil {
return trace.Wrap(err)
}
}

if fc.Auth.MessageOfTheDay != "" {
cfg.Auth.Preference.SetMessageOfTheDay(fc.Auth.MessageOfTheDay)
}

Expand Down
8 changes: 6 additions & 2 deletions lib/web/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -812,33 +812,37 @@ func (h *Handler) pingWithConnector(w http.ResponseWriter, r *http.Request, p ht
ServerVersion: teleport.Version,
}

hasMessageOfTheDay := cap.GetMessageOfTheDay() != ""
if connectorName == constants.Local {
as, err := localSettings(cap)
response.Auth, err = localSettings(cap)
if err != nil {
return nil, trace.Wrap(err)
}
response.Auth = as
response.Auth.HasMessageOfTheDay = hasMessageOfTheDay
return response, nil
}

// first look for a oidc connector with that name
oidcConnector, err := authClient.GetOIDCConnector(r.Context(), connectorName, false)
if err == nil {
response.Auth = oidcSettings(oidcConnector, cap)
response.Auth.HasMessageOfTheDay = hasMessageOfTheDay
return response, nil
}

// if no oidc connector was found, look for a saml connector
samlConnector, err := authClient.GetSAMLConnector(r.Context(), connectorName, false)
lxea marked this conversation as resolved.
Show resolved Hide resolved
if err == nil {
response.Auth = samlSettings(samlConnector, cap)
response.Auth.HasMessageOfTheDay = hasMessageOfTheDay
return response, nil
}

// look for github connector
githubConnector, err := authClient.GetGithubConnector(r.Context(), connectorName, false)
if err == nil {
response.Auth = githubSettings(githubConnector, cap)
response.Auth.HasMessageOfTheDay = hasMessageOfTheDay
return response, nil
}

Expand Down
32 changes: 32 additions & 0 deletions tool/tsh/tsh.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,8 @@ type CLIConf struct {

// overrideStdout allows to switch standard output source for resource command. Used in tests.
overrideStdout io.Writer
// overrideStderr allows to switch standard error source for resource command. Used in tests.
overrideStderr io.Writer

// mockSSOLogin used in tests to override sso login handler in teleport client.
mockSSOLogin client.SSOLoginFunc
Expand Down Expand Up @@ -302,6 +304,14 @@ func (c *CLIConf) Stdout() io.Writer {
return os.Stdout
}

// Stderr returns the stderr writer.
func (c *CLIConf) Stderr() io.Writer {
if c.overrideStderr != nil {
return c.overrideStderr
}
return os.Stderr
}

func main() {
cmdLineOrig := os.Args[1:]
var cmdLine []string
Expand Down Expand Up @@ -880,23 +890,37 @@ func onLogin(cf *CLIConf) error {
// in case if nothing is specified, re-fetch kube clusters and print
// current status
case cf.Proxy == "" && cf.SiteName == "" && cf.DesiredRoles == "" && cf.RequestID == "" && cf.IdentityFileOut == "":
_, err := tc.PingAndShowMOTD(cf.Context)
lxea marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return trace.Wrap(err)
}
if err := updateKubeConfig(cf, tc, ""); err != nil {
return trace.Wrap(err)
}
printProfiles(cf.Debug, profile, profiles)

return nil
// in case if parameters match, re-fetch kube clusters and print
// current status
case host(cf.Proxy) == host(profile.ProxyURL.Host) && cf.SiteName == profile.Cluster && cf.DesiredRoles == "" && cf.RequestID == "":
_, err := tc.PingAndShowMOTD(cf.Context)
if err != nil {
return trace.Wrap(err)
}
if err := updateKubeConfig(cf, tc, ""); err != nil {
return trace.Wrap(err)
}
printProfiles(cf.Debug, profile, profiles)

return nil
// proxy is unspecified or the same as the currently provided proxy,
// but cluster is specified, treat this as selecting a new cluster
// for the same proxy
case (cf.Proxy == "" || host(cf.Proxy) == host(profile.ProxyURL.Host)) && cf.SiteName != "":
_, err := tc.PingAndShowMOTD(cf.Context)
if err != nil {
return trace.Wrap(err)
}
// trigger reissue, preserving any active requests.
err = tc.ReissueUserCerts(cf.Context, client.CertCacheKeep, client.ReissueParams{
AccessRequests: profile.ActiveRequests.AccessRequests,
Expand All @@ -911,11 +935,16 @@ func onLogin(cf *CLIConf) error {
if err := updateKubeConfig(cf, tc, ""); err != nil {
return trace.Wrap(err)
}

return trace.Wrap(onStatus(cf))
// proxy is unspecified or the same as the currently provided proxy,
// but desired roles or request ID is specified, treat this as a
// privilege escalation request for the same login session.
case (cf.Proxy == "" || host(cf.Proxy) == host(profile.ProxyURL.Host)) && (cf.DesiredRoles != "" || cf.RequestID != "") && cf.IdentityFileOut == "":
_, err := tc.PingAndShowMOTD(cf.Context)
if err != nil {
return trace.Wrap(err)
}
if err := executeAccessRequest(cf, tc); err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -2112,6 +2141,9 @@ func makeClient(cf *CLIConf, useProfileLogin bool) (*client.TeleportClient, erro
}
}

tc.Config.Stderr = cf.Stderr()
tc.Config.Stdout = cf.Stdout()

tc.Config.Reason = cf.Reason
tc.Config.Invited = cf.Invited
tc.Config.DisplayParticipantRequirements = cf.displayParticipantRequirements
Expand Down
73 changes: 64 additions & 9 deletions tool/tsh/tsh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ limitations under the License.
package main

import (
"bufio"
"bytes"
"context"
"fmt"
"io/ioutil"
Expand Down Expand Up @@ -175,7 +177,11 @@ func TestOIDCLogin(t *testing.T) {

connector := mockConnector(t)

authProcess, proxyProcess := makeTestServers(t, withBootstrap(populist, dictator, connector, alice))
motd := "MESSAGE_OF_THE_DAY_OIDC"
authProcess, proxyProcess := makeTestServers(t,
withBootstrap(populist, dictator, connector, alice),
withMOTD(t, motd),
)

authServer := authProcess.GetAuthServer()
require.NotNil(t, authServer)
Expand Down Expand Up @@ -213,6 +219,8 @@ func TestOIDCLogin(t *testing.T) {
}
}()

buf := bytes.NewBuffer([]byte{})
sc := bufio.NewScanner(buf)
err = Run([]string{
"login",
"--insecure",
Expand All @@ -223,6 +231,7 @@ func TestOIDCLogin(t *testing.T) {
}, setHomePath(tmpHomePath), cliOption(func(cf *CLIConf) error {
cf.mockSSOLogin = mockSSOLogin(t, authServer, alice)
cf.SiteName = "localhost"
cf.overrideStderr = buf
return nil
}))

Expand All @@ -231,11 +240,22 @@ func TestOIDCLogin(t *testing.T) {
// verify that auto-request happened
require.True(t, didAutoRequest.Load())

findMOTD(t, sc, motd)
// if we got this far, then tsh successfully registered name change from `alice` to
// `[email protected]`, since the correct name needed to be used for the access
// request to be generated.
}

func findMOTD(t *testing.T, sc *bufio.Scanner, motd string) {
t.Helper()
for sc.Scan() {
if strings.Contains(sc.Text(), motd) {
return
}
}
require.Fail(t, "Failed to find %q MOTD in the logs", motd)
}

// TestLoginIdentityOut makes sure that "tsh login --out <ident>" command
// writes identity credentials to the specified path.
func TestLoginIdentityOut(t *testing.T) {
Expand Down Expand Up @@ -283,14 +303,20 @@ func TestRelogin(t *testing.T) {
require.NoError(t, err)
alice.SetRoles([]string{"access"})

authProcess, proxyProcess := makeTestServers(t, withBootstrap(connector, alice))
motd := "RELOGIN MOTD PRESENT"
authProcess, proxyProcess := makeTestServers(t,
withBootstrap(connector, alice),
withMOTD(t, motd),
)

authServer := authProcess.GetAuthServer()
require.NotNil(t, authServer)

proxyAddr, err := proxyProcess.ProxyWebAddr()
require.NoError(t, err)

buf := bytes.NewBuffer([]byte{})
sc := bufio.NewScanner(buf)
err = Run([]string{
"login",
"--insecure",
Expand All @@ -299,20 +325,32 @@ func TestRelogin(t *testing.T) {
"--proxy", proxyAddr.String(),
}, setHomePath(tmpHomePath), cliOption(func(cf *CLIConf) error {
cf.mockSSOLogin = mockSSOLogin(t, authServer, alice)
cf.overrideStderr = buf
return nil
}))
require.NoError(t, err)
findMOTD(t, sc, motd)

err = Run([]string{
"login",
"--insecure",
"--debug",
"--proxy", proxyAddr.String(),
"localhost",
}, setHomePath(tmpHomePath))
}, setHomePath(tmpHomePath),
cliOption(func(cf *CLIConf) error {
cf.mockSSOLogin = mockSSOLogin(t, authServer, alice)
cf.overrideStderr = buf
return nil
}))
require.NoError(t, err)
findMOTD(t, sc, motd)

err = Run([]string{"logout"}, setHomePath(tmpHomePath))
err = Run([]string{"logout"}, setHomePath(tmpHomePath),
cliOption(func(cf *CLIConf) error {
cf.overrideStderr = buf
return nil
}))
require.NoError(t, err)

err = Run([]string{
Expand All @@ -324,8 +362,10 @@ func TestRelogin(t *testing.T) {
"localhost",
}, setHomePath(tmpHomePath), cliOption(func(cf *CLIConf) error {
cf.mockSSOLogin = mockSSOLogin(t, authServer, alice)
cf.overrideStderr = buf
return nil
}))
findMOTD(t, sc, motd)
require.NoError(t, err)
}

Expand Down Expand Up @@ -1238,8 +1278,8 @@ func TestSetX11Config(t *testing.T) {
}

type testServersOpts struct {
bootstrap []types.Resource
authConfigFunc func(cfg *service.AuthConfig)
bootstrap []types.Resource
authConfigFuncs []func(cfg *service.AuthConfig)
}

type testServerOptFunc func(o *testServersOpts)
Expand All @@ -1252,7 +1292,11 @@ func withBootstrap(bootstrap ...types.Resource) testServerOptFunc {

func withAuthConfig(fn func(cfg *service.AuthConfig)) testServerOptFunc {
return func(o *testServersOpts) {
o.authConfigFunc = fn
if o.authConfigFuncs == nil {
o.authConfigFuncs = []func(cfg *service.AuthConfig){}
}

o.authConfigFuncs = append(o.authConfigFuncs, fn)
}
}

Expand All @@ -1267,6 +1311,17 @@ func withClusterName(t *testing.T, n string) testServerOptFunc {
})
}

func withMOTD(t *testing.T, motd string) testServerOptFunc {
oldpass := client.PasswordFromConsoleFn
*client.PasswordFromConsoleFn = func() (string, error) {
return "", nil
}
t.Cleanup(func() { *client.PasswordFromConsoleFn = *oldpass })
return withAuthConfig(func(cfg *service.AuthConfig) {
cfg.Preference.SetMessageOfTheDay(motd)
})
}

func makeTestServers(t *testing.T, opts ...testServerOptFunc) (auth *service.TeleportProcess, proxy *service.TeleportProcess) {
var options testServersOpts
for _, opt := range opts {
Expand Down Expand Up @@ -1299,8 +1354,8 @@ func makeTestServers(t *testing.T, opts ...testServerOptFunc) (auth *service.Tel
cfg.Proxy.Enabled = false
cfg.Log = utils.NewLoggerForTests()

if options.authConfigFunc != nil {
options.authConfigFunc(&cfg.Auth)
for _, fn := range options.authConfigFuncs {
fn(&cfg.Auth)
}

auth, err = service.NewTeleport(cfg)
Expand Down