diff --git a/api/profile/profile.go b/api/profile/profile.go index 7fef9b2cb0471..acd08fa99ab7e 100644 --- a/api/profile/profile.go +++ b/api/profile/profile.go @@ -109,8 +109,41 @@ func (p *Profile) TLSConfig() (*tls.Config, error) { return nil, trace.Wrap(err) } + pool, err := certPoolFromProfile(p) + if err != nil { + return nil, trace.Wrap(err) + } + + return &tls.Config{ + Certificates: []tls.Certificate{cert}, + RootCAs: pool, + }, nil +} + +func certPoolFromProfile(p *Profile) (*x509.CertPool, error) { + // Check if CAS dir exist if not try to load certs from legacy certs.pem file. + if _, err := os.Stat(p.TLSClusterCASDir()); err != nil { + if !os.IsNotExist(err) { + return nil, trace.Wrap(err) + } + pool, err := certPoolFromLegacyCAFile(p) + if err != nil { + return nil, trace.Wrap(err) + } + return pool, nil + } + + // Load CertPool from CAS directory. + pool, err := certPoolFromCASDir(p) + if err != nil { + return nil, trace.Wrap(err) + } + return pool, nil +} + +func certPoolFromCASDir(p *Profile) (*x509.CertPool, error) { pool := x509.NewCertPool() - err = filepath.Walk(p.TLSClusterCASDir(), func(path string, info fs.FileInfo, err error) error { + err := filepath.Walk(p.TLSClusterCASDir(), func(path string, info fs.FileInfo, err error) error { if err != nil { return trace.Wrap(err) } @@ -129,11 +162,19 @@ func (p *Profile) TLSConfig() (*tls.Config, error) { if err != nil { return nil, trace.Wrap(err) } + return pool, nil +} - return &tls.Config{ - Certificates: []tls.Certificate{cert}, - RootCAs: pool, - }, nil +func certPoolFromLegacyCAFile(p *Profile) (*x509.CertPool, error) { + caCerts, err := os.ReadFile(p.TLSCAsPath()) + if err != nil { + return nil, trace.ConvertSystemError(err) + } + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(caCerts) { + return nil, trace.BadParameter("invalid CA cert PEM") + } + return pool, nil } // SSHClientConfig returns the profile's associated SSHClientConfig. @@ -336,6 +377,11 @@ func (p *Profile) TLSCertPath() string { return keypaths.TLSCertPath(p.Dir, p.Name(), p.Username) } +// TLSCAsLegacyPath returns the path to the profile's TLS certificate authorities. +func (p *Profile) TLSCAsLegacyPath() string { + return keypaths.TLSCAsPath(p.Dir, p.Name()) +} + // TLSCAPathCluster returns CA for particular cluster. func (p *Profile) TLSCAPathCluster(cluster string) string { return keypaths.TLSCAsPathCluster(p.Dir, p.Name(), cluster) @@ -346,6 +392,11 @@ func (p *Profile) TLSClusterCASDir() string { return keypaths.CAsDir(p.Dir, p.Name()) } +// TLSCAsPath returns the legacy path to the profile's TLS certificate authorities. +func (p *Profile) TLSCAsPath() string { + return keypaths.TLSCAsPath(p.Dir, p.Name()) +} + // SSHDir returns the path to the profile's ssh directory. func (p *Profile) SSHDir() string { return keypaths.SSHDir(p.Dir, p.Name(), p.Username) diff --git a/api/utils/keypaths/keypaths.go b/api/utils/keypaths/keypaths.go index 8ea41c49a30b0..9eb3b53339553 100644 --- a/api/utils/keypaths/keypaths.go +++ b/api/utils/keypaths/keypaths.go @@ -34,6 +34,8 @@ const ( fileNameKnownHosts = "known_hosts" // fileExtTLSCert is the suffix/extension of a file where a TLS cert is stored. fileExtTLSCert = "-x509.pem" + // fileNameTLSCerts is a file where TLS Cert Authorities are stored. + fileNameTLSCerts = "certs.pem" // fileExtCert is the suffix/extension of a file where an SSH Cert is stored. fileExtSSHCert = "-cert.pub" // fileExtPub is the extension of a file where a public key is stored. @@ -144,6 +146,14 @@ func CAsDir(baseDir, proxy string) string { return filepath.Join(ProxyKeyDir(baseDir, proxy), casDir) } +// TLSCAsPath returns the path to the users's TLS CA's certificates +// for the given proxy. +// /keys//certs.pem +// DELETE IN 10.0. Deprecated +func TLSCAsPath(baseDir, proxy string) string { + return filepath.Join(ProxyKeyDir(baseDir, proxy), fileNameTLSCerts) +} + // TLSCAsPathCluster returns the path to the specified cluster's CA directory. // // /keys//cas/.pem diff --git a/lib/client/keystore.go b/lib/client/keystore.go index 868bda99381e4..47876a3c25216 100644 --- a/lib/client/keystore.go +++ b/lib/client/keystore.go @@ -520,6 +520,11 @@ func (fs *fsLocalNonSessionKeyStore) tlsCertPath(idx KeyIndex) string { return keypaths.TLSCertPath(fs.KeyDir, idx.ProxyHost, idx.Username) } +// tlsCAsPath returns the TLS CA certificates path for the given KeyIndex. +func (fs *fsLocalNonSessionKeyStore) tlsCAsPath(proxy string) string { + return keypaths.TLSCAsPath(fs.KeyDir, proxy) +} + // sshCertPath returns the SSH certificate path for the given KeyIndex. func (fs *fsLocalNonSessionKeyStore) sshCertPath(idx KeyIndex) string { return keypaths.SSHCertPath(fs.KeyDir, idx.ProxyHost, idx.Username, idx.ClusterName) @@ -680,6 +685,20 @@ func (fs *fsLocalNonSessionKeyStore) SaveTrustedCerts(proxyHost string, cas []au return trace.ConvertSystemError(err) } + // Save trusted clusters certs in CAS directory. + if err := fs.saveTrustedCertsInCASDir(proxyHost, cas); err != nil { + return trace.Wrap(err) + } + + // For backward compatibility save trusted in legacy certs.pem file. + if err := fs.saveTrustedCertsInLegacyCAFile(proxyHost, cas); err != nil { + return trace.Wrap(err) + } + + return nil +} + +func (fs *fsLocalNonSessionKeyStore) saveTrustedCertsInCASDir(proxyHost string, cas []auth.TrustedCerts) error { casDirPath := filepath.Join(fs.casDir(proxyHost)) if err := os.MkdirAll(casDirPath, os.ModeDir|profileDirPerms); err != nil { fs.log.Error(err) @@ -700,11 +719,30 @@ func (fs *fsLocalNonSessionKeyStore) SaveTrustedCerts(proxyHost string, cas []au if err := writeClusterCertificates(caFile, ca.TLSCertificates); err != nil { return trace.Wrap(err) } - } return nil } +func (fs *fsLocalNonSessionKeyStore) saveTrustedCertsInLegacyCAFile(proxyHost string, cas []auth.TrustedCerts) (retErr error) { + certsFile := fs.tlsCAsPath(proxyHost) + fp, err := os.OpenFile(certsFile, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0640) + if err != nil { + return trace.ConvertSystemError(err) + } + defer utils.StoreErrorOf(fp.Close, &retErr) + for _, ca := range cas { + for _, cert := range ca.TLSCertificates { + if _, err := fp.Write(cert); err != nil { + return trace.ConvertSystemError(err) + } + if _, err := fmt.Fprintln(fp); err != nil { + return trace.ConvertSystemError(err) + } + } + } + return fp.Sync() +} + // isSafeClusterName check if cluster name is safe and doesn't contain miscellaneous characters. func isSafeClusterName(name string) bool { return !strings.Contains(name, "..") diff --git a/tool/tsh/tsh_helper_test.go b/tool/tsh/tsh_helper_test.go index 7196009f52c34..187676d63e1a7 100644 --- a/tool/tsh/tsh_helper_test.go +++ b/tool/tsh/tsh_helper_test.go @@ -17,6 +17,7 @@ limitations under the License. package main import ( + "context" "fmt" "os/user" "testing" @@ -24,6 +25,7 @@ import ( "github.com/stretchr/testify/require" + apiclient "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/config" "github.com/gravitational/teleport/lib/service" @@ -246,3 +248,16 @@ func waitForEvents(t *testing.T, svc service.Supervisor, events ...string) { } } } + +func mustCreateAuthClientFormUserProfile(t *testing.T, tshHomePath, addr string) { + ctx := context.Background() + credentials := apiclient.LoadProfile(tshHomePath, "") + c, err := apiclient.New(context.Background(), apiclient.Config{ + Addrs: []string{addr}, + Credentials: []apiclient.Credentials{credentials}, + InsecureAddressDiscovery: true, + }) + require.NoError(t, err) + _, err = c.Ping(ctx) + require.NoError(t, err) +} diff --git a/tool/tsh/tsh_test.go b/tool/tsh/tsh_test.go index 961754f94fbb0..b33150f0b65e7 100644 --- a/tool/tsh/tsh_test.go +++ b/tool/tsh/tsh_test.go @@ -36,6 +36,7 @@ import ( "github.com/gravitational/teleport/api/constants" apidefaults "github.com/gravitational/teleport/api/defaults" + "github.com/gravitational/teleport/api/profile" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib" "github.com/gravitational/teleport/lib/auth" @@ -1234,6 +1235,45 @@ func TestSetX11Config(t *testing.T) { } } +// TestAuthClientFromTSHProfile tests if API Client can be successfully created from tsh profile where clusters +// certs are stored separately in CAS directory and in case where legacy certs.pem file was used. +func TestAuthClientFromTSHProfile(t *testing.T) { + tmpHomePath := t.TempDir() + + connector := mockConnector(t) + alice, err := types.NewUser("alice@example.com") + require.NoError(t, err) + alice.SetRoles([]string{"access"}) + authProcess, proxyProcess := makeTestServers(t, withBootstrap(connector, alice)) + authServer := authProcess.GetAuthServer() + require.NotNil(t, authServer) + proxyAddr, err := proxyProcess.ProxyWebAddr() + require.NoError(t, err) + + err = Run([]string{ + "login", + "--insecure", + "--debug", + "--auth", connector.GetName(), + "--proxy", proxyAddr.String(), + }, setHomePath(tmpHomePath), func(cf *CLIConf) error { + cf.mockSSOLogin = mockSSOLogin(t, authServer, alice) + return nil + }) + require.NoError(t, err) + + profile, err := profile.FromDir(tmpHomePath, "") + require.NoError(t, err) + + mustCreateAuthClientFormUserProfile(t, tmpHomePath, proxyAddr.String()) + + // Simulate legacy tsh client behavior where all clusters certs were stored in the certs.pem file. + require.NoError(t, os.RemoveAll(profile.TLSClusterCASDir())) + + // Verify that authClient created from profile will create a valid client in case where cas dir doesn't exit. + mustCreateAuthClientFormUserProfile(t, tmpHomePath, proxyAddr.String()) +} + type testServersOpts struct { bootstrap []types.Resource authConfigFuncs []func(cfg *service.AuthConfig)