Skip to content

Commit

Permalink
Add support for backward compatible API Client behavior (#11567)
Browse files Browse the repository at this point in the history
  • Loading branch information
smallinsky authored Apr 1, 2022
1 parent 3999b17 commit dc09f6f
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 6 deletions.
61 changes: 56 additions & 5 deletions api/profile/profile.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,41 @@ func (p *Profile) TLSConfig() (*tls.Config, error) {
return nil, trace.Wrap(err)
}

pool, err := certPoolFromProfile(p)
if err != nil {
return nil, trace.Wrap(err)
}

return &tls.Config{
Certificates: []tls.Certificate{cert},
RootCAs: pool,
}, nil
}

func certPoolFromProfile(p *Profile) (*x509.CertPool, error) {
// Check if CAS dir exist if not try to load certs from legacy certs.pem file.
if _, err := os.Stat(p.TLSClusterCASDir()); err != nil {
if !os.IsNotExist(err) {
return nil, trace.Wrap(err)
}
pool, err := certPoolFromLegacyCAFile(p)
if err != nil {
return nil, trace.Wrap(err)
}
return pool, nil
}

// Load CertPool from CAS directory.
pool, err := certPoolFromCASDir(p)
if err != nil {
return nil, trace.Wrap(err)
}
return pool, nil
}

func certPoolFromCASDir(p *Profile) (*x509.CertPool, error) {
pool := x509.NewCertPool()
err = filepath.Walk(p.TLSClusterCASDir(), func(path string, info fs.FileInfo, err error) error {
err := filepath.Walk(p.TLSClusterCASDir(), func(path string, info fs.FileInfo, err error) error {
if err != nil {
return trace.Wrap(err)
}
Expand All @@ -129,11 +162,19 @@ func (p *Profile) TLSConfig() (*tls.Config, error) {
if err != nil {
return nil, trace.Wrap(err)
}
return pool, nil
}

return &tls.Config{
Certificates: []tls.Certificate{cert},
RootCAs: pool,
}, nil
func certPoolFromLegacyCAFile(p *Profile) (*x509.CertPool, error) {
caCerts, err := os.ReadFile(p.TLSCAsPath())
if err != nil {
return nil, trace.ConvertSystemError(err)
}
pool := x509.NewCertPool()
if !pool.AppendCertsFromPEM(caCerts) {
return nil, trace.BadParameter("invalid CA cert PEM")
}
return pool, nil
}

// SSHClientConfig returns the profile's associated SSHClientConfig.
Expand Down Expand Up @@ -336,6 +377,11 @@ func (p *Profile) TLSCertPath() string {
return keypaths.TLSCertPath(p.Dir, p.Name(), p.Username)
}

// TLSCAsLegacyPath returns the path to the profile's TLS certificate authorities.
func (p *Profile) TLSCAsLegacyPath() string {
return keypaths.TLSCAsPath(p.Dir, p.Name())
}

// TLSCAPathCluster returns CA for particular cluster.
func (p *Profile) TLSCAPathCluster(cluster string) string {
return keypaths.TLSCAsPathCluster(p.Dir, p.Name(), cluster)
Expand All @@ -346,6 +392,11 @@ func (p *Profile) TLSClusterCASDir() string {
return keypaths.CAsDir(p.Dir, p.Name())
}

// TLSCAsPath returns the legacy path to the profile's TLS certificate authorities.
func (p *Profile) TLSCAsPath() string {
return keypaths.TLSCAsPath(p.Dir, p.Name())
}

// SSHDir returns the path to the profile's ssh directory.
func (p *Profile) SSHDir() string {
return keypaths.SSHDir(p.Dir, p.Name(), p.Username)
Expand Down
10 changes: 10 additions & 0 deletions api/utils/keypaths/keypaths.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ const (
fileNameKnownHosts = "known_hosts"
// fileExtTLSCert is the suffix/extension of a file where a TLS cert is stored.
fileExtTLSCert = "-x509.pem"
// fileNameTLSCerts is a file where TLS Cert Authorities are stored.
fileNameTLSCerts = "certs.pem"
// fileExtCert is the suffix/extension of a file where an SSH Cert is stored.
fileExtSSHCert = "-cert.pub"
// fileExtPub is the extension of a file where a public key is stored.
Expand Down Expand Up @@ -144,6 +146,14 @@ func CAsDir(baseDir, proxy string) string {
return filepath.Join(ProxyKeyDir(baseDir, proxy), casDir)
}

// TLSCAsPath returns the path to the users's TLS CA's certificates
// for the given proxy.
// <baseDir>/keys/<proxy>/certs.pem
// DELETE IN 10.0. Deprecated
func TLSCAsPath(baseDir, proxy string) string {
return filepath.Join(ProxyKeyDir(baseDir, proxy), fileNameTLSCerts)
}

// TLSCAsPathCluster returns the path to the specified cluster's CA directory.
//
// <baseDir>/keys/<proxy>/cas/<cluster>.pem
Expand Down
40 changes: 39 additions & 1 deletion lib/client/keystore.go
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,11 @@ func (fs *fsLocalNonSessionKeyStore) tlsCertPath(idx KeyIndex) string {
return keypaths.TLSCertPath(fs.KeyDir, idx.ProxyHost, idx.Username)
}

// tlsCAsPath returns the TLS CA certificates path for the given KeyIndex.
func (fs *fsLocalNonSessionKeyStore) tlsCAsPath(proxy string) string {
return keypaths.TLSCAsPath(fs.KeyDir, proxy)
}

// sshCertPath returns the SSH certificate path for the given KeyIndex.
func (fs *fsLocalNonSessionKeyStore) sshCertPath(idx KeyIndex) string {
return keypaths.SSHCertPath(fs.KeyDir, idx.ProxyHost, idx.Username, idx.ClusterName)
Expand Down Expand Up @@ -680,6 +685,20 @@ func (fs *fsLocalNonSessionKeyStore) SaveTrustedCerts(proxyHost string, cas []au
return trace.ConvertSystemError(err)
}

// Save trusted clusters certs in CAS directory.
if err := fs.saveTrustedCertsInCASDir(proxyHost, cas); err != nil {
return trace.Wrap(err)
}

// For backward compatibility save trusted in legacy certs.pem file.
if err := fs.saveTrustedCertsInLegacyCAFile(proxyHost, cas); err != nil {
return trace.Wrap(err)
}

return nil
}

func (fs *fsLocalNonSessionKeyStore) saveTrustedCertsInCASDir(proxyHost string, cas []auth.TrustedCerts) error {
casDirPath := filepath.Join(fs.casDir(proxyHost))
if err := os.MkdirAll(casDirPath, os.ModeDir|profileDirPerms); err != nil {
fs.log.Error(err)
Expand All @@ -700,11 +719,30 @@ func (fs *fsLocalNonSessionKeyStore) SaveTrustedCerts(proxyHost string, cas []au
if err := writeClusterCertificates(caFile, ca.TLSCertificates); err != nil {
return trace.Wrap(err)
}

}
return nil
}

func (fs *fsLocalNonSessionKeyStore) saveTrustedCertsInLegacyCAFile(proxyHost string, cas []auth.TrustedCerts) (retErr error) {
certsFile := fs.tlsCAsPath(proxyHost)
fp, err := os.OpenFile(certsFile, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0640)
if err != nil {
return trace.ConvertSystemError(err)
}
defer utils.StoreErrorOf(fp.Close, &retErr)
for _, ca := range cas {
for _, cert := range ca.TLSCertificates {
if _, err := fp.Write(cert); err != nil {
return trace.ConvertSystemError(err)
}
if _, err := fmt.Fprintln(fp); err != nil {
return trace.ConvertSystemError(err)
}
}
}
return fp.Sync()
}

// isSafeClusterName check if cluster name is safe and doesn't contain miscellaneous characters.
func isSafeClusterName(name string) bool {
return !strings.Contains(name, "..")
Expand Down
15 changes: 15 additions & 0 deletions tool/tsh/tsh_helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@ limitations under the License.
package main

import (
"context"
"fmt"
"os/user"
"testing"
"time"

"github.com/stretchr/testify/require"

apiclient "github.com/gravitational/teleport/api/client"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/config"
"github.com/gravitational/teleport/lib/service"
Expand Down Expand Up @@ -246,3 +248,16 @@ func waitForEvents(t *testing.T, svc service.Supervisor, events ...string) {
}
}
}

func mustCreateAuthClientFormUserProfile(t *testing.T, tshHomePath, addr string) {
ctx := context.Background()
credentials := apiclient.LoadProfile(tshHomePath, "")
c, err := apiclient.New(context.Background(), apiclient.Config{
Addrs: []string{addr},
Credentials: []apiclient.Credentials{credentials},
InsecureAddressDiscovery: true,
})
require.NoError(t, err)
_, err = c.Ping(ctx)
require.NoError(t, err)
}
40 changes: 40 additions & 0 deletions tool/tsh/tsh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (

"github.com/gravitational/teleport/api/constants"
apidefaults "github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/profile"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib"
"github.com/gravitational/teleport/lib/auth"
Expand Down Expand Up @@ -1234,6 +1235,45 @@ func TestSetX11Config(t *testing.T) {
}
}

// TestAuthClientFromTSHProfile tests if API Client can be successfully created from tsh profile where clusters
// certs are stored separately in CAS directory and in case where legacy certs.pem file was used.
func TestAuthClientFromTSHProfile(t *testing.T) {
tmpHomePath := t.TempDir()

connector := mockConnector(t)
alice, err := types.NewUser("[email protected]")
require.NoError(t, err)
alice.SetRoles([]string{"access"})
authProcess, proxyProcess := makeTestServers(t, withBootstrap(connector, alice))
authServer := authProcess.GetAuthServer()
require.NotNil(t, authServer)
proxyAddr, err := proxyProcess.ProxyWebAddr()
require.NoError(t, err)

err = Run([]string{
"login",
"--insecure",
"--debug",
"--auth", connector.GetName(),
"--proxy", proxyAddr.String(),
}, setHomePath(tmpHomePath), func(cf *CLIConf) error {
cf.mockSSOLogin = mockSSOLogin(t, authServer, alice)
return nil
})
require.NoError(t, err)

profile, err := profile.FromDir(tmpHomePath, "")
require.NoError(t, err)

mustCreateAuthClientFormUserProfile(t, tmpHomePath, proxyAddr.String())

// Simulate legacy tsh client behavior where all clusters certs were stored in the certs.pem file.
require.NoError(t, os.RemoveAll(profile.TLSClusterCASDir()))

// Verify that authClient created from profile will create a valid client in case where cas dir doesn't exit.
mustCreateAuthClientFormUserProfile(t, tmpHomePath, proxyAddr.String())
}

type testServersOpts struct {
bootstrap []types.Resource
authConfigFuncs []func(cfg *service.AuthConfig)
Expand Down

0 comments on commit dc09f6f

Please sign in to comment.