diff --git a/lib/srv/alpnproxy/azure_msi_middleware.go b/lib/srv/alpnproxy/azure_msi_middleware.go index b26adc80a78fd..0875f9fcb8d9f 100644 --- a/lib/srv/alpnproxy/azure_msi_middleware.go +++ b/lib/srv/alpnproxy/azure_msi_middleware.go @@ -20,9 +20,11 @@ package alpnproxy import ( "crypto" + "crypto/tls" "encoding/json" "fmt" "net/http" + "sync" "time" "github.com/gravitational/trace" @@ -45,15 +47,16 @@ type AzureMSIMiddleware struct { // ClientID to be returned in a claim. ClientID string - // Key used to sign JWT - Key crypto.Signer - // Clock is used to override time in tests. Clock clockwork.Clock // Log is the Logger. Log logrus.FieldLogger // Secret to be provided by the client. Secret string + + // Key used to sign JWT + key crypto.Signer + keyMu sync.RWMutex } var _ LocalProxyHTTPMiddleware = &AzureMSIMiddleware{} @@ -66,9 +69,6 @@ func (m *AzureMSIMiddleware) CheckAndSetDefaults() error { m.Log = logrus.WithField(teleport.ComponentKey, "azure_msi") } - if m.Key == nil { - return trace.BadParameter("missing Key") - } if m.Secret == "" { return trace.BadParameter("missing Secret") } @@ -96,6 +96,26 @@ func (m *AzureMSIMiddleware) HandleRequest(rw http.ResponseWriter, req *http.Req return false } +func (m *AzureMSIMiddleware) OnSetCert(cert *tls.Certificate) { + m.keyMu.Lock() + defer m.keyMu.Unlock() + + if cert != nil { + // Note that the PrivateKey is most likely set by api/utils/keys.TLSCertificateForSigner + signer, ok := cert.PrivateKey.(crypto.Signer) + if ok { + m.key = signer + } else { + m.Log.Warn("Provided tls.Certificate has no valid private key") + } + } +} +func (m *AzureMSIMiddleware) getPrivateKey() crypto.Signer { + m.keyMu.RLock() + defer m.keyMu.RUnlock() + return m.key +} + func (m *AzureMSIMiddleware) msiEndpoint(rw http.ResponseWriter, req *http.Request) error { // request validation if req.URL.Path != ("/" + m.Secret) { @@ -176,7 +196,7 @@ func (m *AzureMSIMiddleware) toJWT(claims jwt.AzureTokenClaims) (string, error) // Create a new key that can sign and verify tokens. key, err := jwt.New(&jwt.Config{ Clock: m.Clock, - PrivateKey: m.Key, + PrivateKey: m.getPrivateKey(), ClusterName: types.TeleportAzureMSIEndpoint, // todo get cluster name }) if err != nil { diff --git a/lib/srv/alpnproxy/local_proxy.go b/lib/srv/alpnproxy/local_proxy.go index 0aa7e75a193f2..33ebcfbbec630 100644 --- a/lib/srv/alpnproxy/local_proxy.go +++ b/lib/srv/alpnproxy/local_proxy.go @@ -92,6 +92,8 @@ type LocalProxyConfig struct { CheckCertNeeded bool // verifyUpstreamConnection is a callback function to verify upstream connection state. verifyUpstreamConnection func(tls.ConnectionState) error + // TODO + onSetCert func(*tls.Certificate) } // LocalProxyMiddleware provides callback functions for LocalProxy. @@ -484,6 +486,11 @@ func (l *LocalProxy) SetCert(cert tls.Certificate) { l.certMu.Lock() defer l.certMu.Unlock() l.cfg.Cert = cert + + // Callback, if any. + if l.cfg.onSetCert != nil { + l.cfg.onSetCert(&cert) + } } // getCertForConn determines if certificates should be used when dialing diff --git a/lib/srv/alpnproxy/local_proxy_config_opt.go b/lib/srv/alpnproxy/local_proxy_config_opt.go index 7b520ce37611f..8ea8dda0bfdcf 100644 --- a/lib/srv/alpnproxy/local_proxy_config_opt.go +++ b/lib/srv/alpnproxy/local_proxy_config_opt.go @@ -170,3 +170,10 @@ func mySQLVersionToProto(database types.Database) string { // Include MySQL server version return string(common.ProtocolMySQLWithVerPrefix) + versionBase64 } + +func WithOnSetCert(callback func(*tls.Certificate)) LocalProxyConfigOpt { + return func(config *LocalProxyConfig) error { + config.onSetCert = callback + return nil + } +} diff --git a/tool/tsh/common/app_azure.go b/tool/tsh/common/app_azure.go index 5a8494aae5d3e..c1ffb91f8b42f 100644 --- a/tool/tsh/common/app_azure.go +++ b/tool/tsh/common/app_azure.go @@ -20,7 +20,6 @@ package common import ( "context" - "crypto" "fmt" "os" "os/exec" @@ -72,17 +71,11 @@ type azureApp struct { *localProxyApp cf *CLIConf - signer crypto.Signer msiSecret string } // newAzureApp creates a new Azure app. func newAzureApp(tc *client.TeleportClient, cf *CLIConf, appInfo *appInfo) (*azureApp, error) { - keyRing, err := tc.LocalAgent().GetCoreKeyRing() - if err != nil { - return nil, trace.Wrap(err) - } - msiSecret, err := getMSISecret() if err != nil { return nil, err @@ -91,7 +84,6 @@ func newAzureApp(tc *client.TeleportClient, cf *CLIConf, appInfo *appInfo) (*azu return &azureApp{ localProxyApp: newLocalProxyApp(tc, appInfo, cf.LocalProxyPort, cf.InsecureSkipVerify), cf: cf, - signer: keyRing.TLSPrivateKey, msiSecret: msiSecret, }, nil } @@ -133,7 +125,6 @@ func getMSISecret() (string, error) { // These calls are served entirely locally, which helps the overall performance experienced by the user. func (a *azureApp) StartLocalProxies(ctx context.Context) error { azureMiddleware := &alpnproxy.AzureMSIMiddleware{ - Key: a.signer, Secret: a.msiSecret, // we could, in principle, get the actual TenantID either from live data or from static configuration, // but at this moment there is no clear advantage over simply issuing a new random identifier. @@ -143,7 +134,11 @@ func (a *azureApp) StartLocalProxies(ctx context.Context) error { } // HTTPS proxy mode - err := a.StartLocalProxyWithForwarder(ctx, alpnproxy.MatchAzureRequests, alpnproxy.WithHTTPMiddleware(azureMiddleware)) + err := a.StartLocalProxyWithForwarder(ctx, + alpnproxy.MatchAzureRequests, + alpnproxy.WithHTTPMiddleware(azureMiddleware), + alpnproxy.WithOnSetCert(azureMiddleware.OnSetCert), + ) return trace.Wrap(err) }