Skip to content

Commit

Permalink
Convert lib/srv/alpnproxy to use slog (#50018)
Browse files Browse the repository at this point in the history
  • Loading branch information
rosstimothy authored Dec 12, 2024
1 parent da6dfdc commit 41bae13
Show file tree
Hide file tree
Showing 25 changed files with 127 additions and 124 deletions.
2 changes: 1 addition & 1 deletion integration/proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1236,7 +1236,7 @@ func TestALPNSNIProxyDatabaseAccess(t *testing.T) {

// advance the fake clock and verify that the local proxy thinks its cert expired.
fakeClock.Advance(time.Hour * 48)
err = lp.CheckDBCert(routeToDatabase)
err = lp.CheckDBCert(context.Background(), routeToDatabase)
require.Error(t, err)
var x509Err x509.CertificateInvalidError
require.ErrorAs(t, err, &x509Err)
Expand Down
12 changes: 6 additions & 6 deletions lib/auth/authclient/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"log/slog"
"math"

"github.com/gravitational/trace"
"github.com/sirupsen/logrus"

"github.com/gravitational/teleport/api/types"
apiutils "github.com/gravitational/teleport/api/utils"
Expand Down Expand Up @@ -96,7 +96,7 @@ func DefaultClientCertPool(ctx context.Context, client CAGetter, clusterName str

// WithClusterCAs returns a TLS hello callback that returns a copy of the provided
// TLS config with client CAs pool of the specified cluster.
func WithClusterCAs(tlsConfig *tls.Config, ap CAGetter, currentClusterName string, log logrus.FieldLogger) func(*tls.ClientHelloInfo) (*tls.Config, error) {
func WithClusterCAs(tlsConfig *tls.Config, ap CAGetter, currentClusterName string, logger *slog.Logger) func(*tls.ClientHelloInfo) (*tls.Config, error) {
return func(info *tls.ClientHelloInfo) (*tls.Config, error) {
var clusterName string
var err error
Expand All @@ -105,14 +105,14 @@ func WithClusterCAs(tlsConfig *tls.Config, ap CAGetter, currentClusterName strin
clusterName, err = apiutils.DecodeClusterName(info.ServerName)
if err != nil {
if !trace.IsNotFound(err) {
log.Debugf("Ignoring unsupported cluster name name %q.", info.ServerName)
logger.DebugContext(info.Context(), "Ignoring unsupported cluster name name", "cluster_name", info.ServerName)
clusterName = ""
}
}
}
pool, totalSubjectsLen, err := DefaultClientCertPool(info.Context(), ap, clusterName)
if err != nil {
log.WithError(err).Errorf("Failed to retrieve client pool for %q.", clusterName)
logger.ErrorContext(info.Context(), "Failed to retrieve client pool for cluster", "error", err, "cluster", clusterName)
// this falls back to the default config
return nil, nil
}
Expand All @@ -130,11 +130,11 @@ func WithClusterCAs(tlsConfig *tls.Config, ap CAGetter, currentClusterName strin
// the current cluster CA. In the unlikely case where it's wrong, the
// client will be rejected.
if totalSubjectsLen >= int64(math.MaxUint16) {
log.Debugf("Number of CAs in client cert pool is too large and cannot be encoded in a TLS handshake; this is due to a large number of trusted clusters; will use only the CA of the current cluster to validate.")
logger.DebugContext(info.Context(), "Number of CAs in client cert pool is too large and cannot be encoded in a TLS handshake; this is due to a large number of trusted clusters; will use only the CA of the current cluster to validate")

pool, _, err = DefaultClientCertPool(info.Context(), ap, currentClusterName)
if err != nil {
log.WithError(err).Errorf("Failed to retrieve client pool for %q.", currentClusterName)
logger.ErrorContext(info.Context(), "Failed to retrieve client pool for cluster", "error", err, "cluster", currentClusterName)
// this falls back to the default config
return nil, nil
}
Expand Down
8 changes: 4 additions & 4 deletions lib/kube/grpc/grpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"log/slog"
"net"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/gravitational/trace"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
Expand Down Expand Up @@ -639,7 +639,7 @@ func initGRPCServer(t *testing.T, testCtx *kubeproxy.TestContext, listener net.L
AcceptedUsage: []string{teleport.UsageKubeOnly},
}

tlsConf := copyAndConfigureTLS(tlsConfig, logrus.New(), testCtx.AuthClient, clusterName)
tlsConf := copyAndConfigureTLS(tlsConfig, testCtx.AuthClient, clusterName)
creds, err := auth.NewTransportCredentials(auth.TransportCredentialsConfig{
TransportCredentials: credentials.NewTLS(tlsConf),
UserGetter: authMiddleware,
Expand Down Expand Up @@ -693,7 +693,7 @@ func initGRPCServer(t *testing.T, testCtx *kubeproxy.TestContext, listener net.L

// copyAndConfigureTLS can be used to copy and modify an existing *tls.Config
// for Teleport application proxy servers.
func copyAndConfigureTLS(config *tls.Config, log logrus.FieldLogger, accessPoint authclient.AccessCache, clusterName string) *tls.Config {
func copyAndConfigureTLS(config *tls.Config, accessPoint authclient.AccessCache, clusterName string) *tls.Config {
tlsConfig := config.Clone()

// Require clients to present a certificate
Expand All @@ -703,7 +703,7 @@ func copyAndConfigureTLS(config *tls.Config, log logrus.FieldLogger, accessPoint
// client's certificate to verify the chain presented. If the client does not
// pass in the cluster name, this functions pulls back all CA to try and
// match the certificate presented against any CA.
tlsConfig.GetConfigForClient = authclient.WithClusterCAs(tlsConfig.Clone(), accessPoint, clusterName, log)
tlsConfig.GetConfigForClient = authclient.WithClusterCAs(tlsConfig.Clone(), accessPoint, clusterName, slog.Default())

return tlsConfig
}
Expand Down
4 changes: 3 additions & 1 deletion lib/kube/proxy/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package proxy
import (
"context"
"crypto/tls"
"log/slog"
"maps"
"net"
"net/http"
Expand Down Expand Up @@ -421,7 +422,8 @@ func (t *TLSServer) close(ctx context.Context) error {
// and server's GetConfigForClient reloads the list of trusted
// local and remote certificate authorities
func (t *TLSServer) GetConfigForClient(info *tls.ClientHelloInfo) (*tls.Config, error) {
return authclient.WithClusterCAs(t.TLS, t.AccessPoint, t.ClusterName, t.log)(info)
// TODO(tross): remove slog.Default once the TLSServer is updated to use a slog.Logger
return authclient.WithClusterCAs(t.TLS, t.AccessPoint, t.ClusterName, slog.Default())(info)
}

// GetServerInfo returns a services.Server object for heartbeats (aka
Expand Down
5 changes: 2 additions & 3 deletions lib/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -5273,7 +5273,6 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
clusterName,
utils.NetAddrsToStrings(process.Config.AuthServerAddresses()),
proxySigner,
process.log,
process.TracingProvider.Tracer(teleport.ComponentProxy))

alpnRouter.Add(alpnproxy.HandlerDecs{
Expand Down Expand Up @@ -6698,7 +6697,7 @@ func (process *TeleportProcess) initSecureGRPCServer(cfg initSecureGRPCServerCfg

tlsConf := serverTLSConfig.Clone()
tlsConf.NextProtos = []string{string(alpncommon.ProtocolHTTP2), string(alpncommon.ProtocolProxyGRPCSecure)}
tlsConf = copyAndConfigureTLS(tlsConf, process.log, cfg.accessPoint, clusterName)
tlsConf = copyAndConfigureTLS(tlsConf, process.logger, cfg.accessPoint, clusterName)
creds, err := auth.NewTransportCredentials(auth.TransportCredentialsConfig{
TransportCredentials: credentials.NewTLS(tlsConf),
UserGetter: authMiddleware,
Expand Down Expand Up @@ -6752,7 +6751,7 @@ type initSecureGRPCServerCfg struct {

// copyAndConfigureTLS can be used to copy and modify an existing *tls.Config
// for Teleport application proxy servers.
func copyAndConfigureTLS(config *tls.Config, log logrus.FieldLogger, accessPoint authclient.AccessCache, clusterName string) *tls.Config {
func copyAndConfigureTLS(config *tls.Config, log *slog.Logger, accessPoint authclient.AccessCache, clusterName string) *tls.Config {
tlsConfig := config.Clone()

// Require clients to present a certificate
Expand Down
5 changes: 1 addition & 4 deletions lib/srv/alpnproxy/auth/auth_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import (
"strings"

"github.com/gravitational/trace"
"github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/attribute"
oteltrace "go.opentelemetry.io/otel/trace"

Expand All @@ -44,13 +43,12 @@ type sitesGetter interface {
}

// NewAuthProxyDialerService create new instance of AuthProxyDialerService.
func NewAuthProxyDialerService(reverseTunnelServer sitesGetter, localClusterName string, authServers []string, proxySigner multiplexer.PROXYHeaderSigner, log logrus.FieldLogger, tracer oteltrace.Tracer) *AuthProxyDialerService {
func NewAuthProxyDialerService(reverseTunnelServer sitesGetter, localClusterName string, authServers []string, proxySigner multiplexer.PROXYHeaderSigner, tracer oteltrace.Tracer) *AuthProxyDialerService {
return &AuthProxyDialerService{
reverseTunnelServer: reverseTunnelServer,
localClusterName: localClusterName,
authServers: authServers,
proxySigner: proxySigner,
log: log,
tracer: tracer,
}
}
Expand All @@ -62,7 +60,6 @@ type AuthProxyDialerService struct {
localClusterName string
authServers []string
proxySigner multiplexer.PROXYHeaderSigner
log logrus.FieldLogger
tracer oteltrace.Tracer
}

Expand Down
6 changes: 3 additions & 3 deletions lib/srv/alpnproxy/auth/auth_proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import (
)

func TestDialLocalAuthServerNoServers(t *testing.T) {
s := NewAuthProxyDialerService(nil /* reverseTunnelServer */, "clustername", nil /* authServers */, nil, nil, tracing.NoopTracer("test"))
s := NewAuthProxyDialerService(nil /* reverseTunnelServer */, "clustername", nil /* authServers */, nil, tracing.NoopTracer("test"))
_, err := s.dialLocalAuthServer(context.Background(), nil, nil)
require.Error(t, err, "dialLocalAuthServer expected to fail")
require.Equal(t, "empty auth servers list", err.Error())
Expand All @@ -40,7 +40,7 @@ func TestDialLocalAuthServerNoServers(t *testing.T) {
func TestDialLocalAuthServerNoAvailableServers(t *testing.T) {
// The 203.0.113.0/24 range is part of block TEST-NET-3 as defined in RFC-5735 (https://www.rfc-editor.org/rfc/rfc5735).
// IPs in this range do not appear on the public internet.
s := NewAuthProxyDialerService(nil /* reverseTunnelServer */, "clustername", []string{"203.0.113.1:3025"}, nil, nil, tracing.NoopTracer("test"))
s := NewAuthProxyDialerService(nil /* reverseTunnelServer */, "clustername", []string{"203.0.113.1:3025"}, nil, tracing.NoopTracer("test"))
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
t.Cleanup(cancel)
_, err := s.dialLocalAuthServer(ctx, nil, nil)
Expand All @@ -64,7 +64,7 @@ func TestDialLocalAuthServerAvailableServers(t *testing.T) {
// IPs in this range do not appear on the public internet.
authServers = append(authServers, fmt.Sprintf("203.0.113.%d:3025", i+1))
}
s := NewAuthProxyDialerService(nil /* reverseTunnelServer */, "clustername", authServers, nil, nil, tracing.NoopTracer("test"))
s := NewAuthProxyDialerService(nil /* reverseTunnelServer */, "clustername", authServers, nil, tracing.NoopTracer("test"))
require.Eventually(t, func() bool {
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
t.Cleanup(cancel)
Expand Down
8 changes: 4 additions & 4 deletions lib/srv/alpnproxy/auth_checker_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ package alpnproxy

import (
"crypto/subtle"
"log/slog"
"net/http"

"github.com/gravitational/trace"
"github.com/sirupsen/logrus"

"github.com/gravitational/teleport"
)
Expand All @@ -35,7 +35,7 @@ type AuthorizationCheckerMiddleware struct {
DefaultLocalProxyHTTPMiddleware

// Log is the Logger.
Log logrus.FieldLogger
Log *slog.Logger
// Secret is the expected value of a bearer token.
Secret string
}
Expand All @@ -45,7 +45,7 @@ var _ LocalProxyHTTPMiddleware = (*AuthorizationCheckerMiddleware)(nil)
// CheckAndSetDefaults checks configuration validity and sets defaults.
func (m *AuthorizationCheckerMiddleware) CheckAndSetDefaults() error {
if m.Log == nil {
m.Log = logrus.WithField(teleport.ComponentKey, "gcp")
m.Log = slog.With(teleport.ComponentKey, "authz")
}

if m.Secret == "" {
Expand All @@ -58,7 +58,7 @@ func (m *AuthorizationCheckerMiddleware) CheckAndSetDefaults() error {
func (m *AuthorizationCheckerMiddleware) HandleRequest(rw http.ResponseWriter, req *http.Request) bool {
auth := req.Header.Get("Authorization")
if auth == "" {
m.Log.Debugf("No Authorization header present, ignoring request.")
m.Log.DebugContext(req.Context(), "No Authorization header present, ignoring request")
return false
}

Expand Down
20 changes: 10 additions & 10 deletions lib/srv/alpnproxy/aws_local_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@
package alpnproxy

import (
"log/slog"
"net/http"
"strings"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/gravitational/trace"
"github.com/sirupsen/logrus"

"github.com/gravitational/teleport"
awsapiutils "github.com/gravitational/teleport/api/utils/aws"
Expand All @@ -43,7 +43,7 @@ type AWSAccessMiddleware struct {
// signature verification.
AWSCredentialsProvider aws.CredentialsProvider

Log logrus.FieldLogger
Log *slog.Logger

assumedRoles utils.SyncMap[string, *sts.AssumeRoleOutput]
}
Expand All @@ -52,7 +52,7 @@ var _ LocalProxyHTTPMiddleware = &AWSAccessMiddleware{}

func (m *AWSAccessMiddleware) CheckAndSetDefaults() error {
if m.Log == nil {
m.Log = logrus.WithField(teleport.ComponentKey, "aws_access")
m.Log = slog.With(teleport.ComponentKey, "aws_access")
}

if m.AWSCredentialsProvider == nil {
Expand Down Expand Up @@ -113,7 +113,7 @@ func (m *AWSAccessMiddleware) CheckAndSetDefaults() error {
func (m *AWSAccessMiddleware) HandleRequest(rw http.ResponseWriter, req *http.Request) bool {
sigV4, err := awsutils.ParseSigV4(req.Header.Get(awsutils.AuthorizationHeader))
if err != nil {
m.Log.WithError(err).Error("Failed to parse AWS request authorization header.")
m.Log.ErrorContext(req.Context(), "Failed to parse AWS request authorization header", "error", err)
rw.WriteHeader(http.StatusForbidden)
return true
}
Expand All @@ -135,7 +135,7 @@ func (m *AWSAccessMiddleware) HandleRequest(rw http.ResponseWriter, req *http.Re

func (m *AWSAccessMiddleware) handleCommonRequest(rw http.ResponseWriter, req *http.Request) bool {
if err := awsutils.VerifyAWSSignatureV2(req, m.AWSCredentialsProvider); err != nil {
m.Log.WithError(err).Error("AWS signature verification failed.")
m.Log.ErrorContext(req.Context(), "AWS signature verification failed", "error", err)
rw.WriteHeader(http.StatusForbidden)
return true
}
Expand All @@ -150,12 +150,12 @@ func (m *AWSAccessMiddleware) handleRequestByAssumedRole(rw http.ResponseWriter,
)

if err := awsutils.VerifyAWSSignatureV2(req, credentials); err != nil {
m.Log.WithError(err).Error("AWS signature verification failed.")
m.Log.ErrorContext(req.Context(), "AWS signature verification failed", "error", err)
rw.WriteHeader(http.StatusForbidden)
return true
}

m.Log.Debugf("Rewriting headers for AWS request by assumed role %q.", aws.ToString(assumedRole.AssumedRoleUser.Arn))
m.Log.DebugContext(req.Context(), "Rewriting headers for AWS request by assumed role", "assumed_role", aws.ToString(assumedRole.AssumedRoleUser.Arn))

// Add a custom header for marking the special request.
req.Header.Add(appcommon.TeleportAWSAssumedRole, aws.ToString(assumedRole.AssumedRoleUser.Arn))
Expand All @@ -178,7 +178,7 @@ func (m *AWSAccessMiddleware) HandleResponse(response *http.Response) error {

sigV4, err := awsutils.ParseSigV4(authHeader)
if err != nil {
m.Log.WithError(err).Error("Failed to parse AWS request authorization header.")
m.Log.ErrorContext(response.Request.Context(), "Failed to parse AWS request authorization header", "error", err)
return nil
}

Expand All @@ -205,13 +205,13 @@ func (m *AWSAccessMiddleware) handleSTSResponse(response *http.Response) error {
assumedRole, err := unmarshalAssumeRoleResponse(body)
if err != nil {
if !trace.IsNotFound(err) {
m.Log.Warnf("Failed to unmarshal AssumeRoleResponse: %v.", err)
m.Log.WarnContext(response.Request.Context(), "Failed to unmarshal AssumeRoleResponse", "error", err)
}
return nil
}

m.assumedRoles.Store(aws.ToString(assumedRole.Credentials.AccessKeyId), assumedRole)
m.Log.Debugf("Saved credentials for assumed role %q.", aws.ToString(assumedRole.AssumedRoleUser.Arn))
m.Log.DebugContext(response.Request.Context(), "Saved credentials for assumed role", "assumed_role", aws.ToString(assumedRole.AssumedRoleUser.Arn))
return nil
}

Expand Down
12 changes: 6 additions & 6 deletions lib/srv/alpnproxy/azure_msi_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@ import (
"crypto"
"encoding/json"
"fmt"
"log/slog"
"net/http"
"sync"
"time"

"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/sirupsen/logrus"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/types"
Expand All @@ -49,7 +49,7 @@ type AzureMSIMiddleware struct {
// Clock is used to override time in tests.
Clock clockwork.Clock
// Log is the Logger.
Log logrus.FieldLogger
Log *slog.Logger
// Secret to be provided by the client.
Secret string

Expand All @@ -65,7 +65,7 @@ func (m *AzureMSIMiddleware) CheckAndSetDefaults() error {
m.Clock = clockwork.NewRealClock()
}
if m.Log == nil {
m.Log = logrus.WithField(teleport.ComponentKey, "azure_msi")
m.Log = slog.With(teleport.ComponentKey, "azure_msi")
}

if m.Secret == "" {
Expand All @@ -86,7 +86,7 @@ func (m *AzureMSIMiddleware) CheckAndSetDefaults() error {
func (m *AzureMSIMiddleware) HandleRequest(rw http.ResponseWriter, req *http.Request) bool {
if req.Host == types.TeleportAzureMSIEndpoint {
if err := m.msiEndpoint(rw, req); err != nil {
m.Log.Warnf("Bad MSI request: %v", err)
m.Log.WarnContext(req.Context(), "Bad MSI request", "error", err)
trace.WriteError(rw, trace.Wrap(err))
}
return true
Expand Down Expand Up @@ -135,7 +135,7 @@ func (m *AzureMSIMiddleware) msiEndpoint(rw http.ResponseWriter, req *http.Reque
// check that msi_res_id matches expected Azure Identity
requestedAzureIdentity := req.Form.Get("msi_res_id")
if requestedAzureIdentity != m.Identity {
m.Log.Warnf("Requested unexpected identity %q, expected %q", requestedAzureIdentity, m.Identity)
m.Log.WarnContext(req.Context(), "Requested unexpected identity", "requested_identity", requestedAzureIdentity, "expected_identity", m.Identity)
return trace.BadParameter("unexpected value for parameter 'msi_res_id': %v", requestedAzureIdentity)
}

Expand All @@ -144,7 +144,7 @@ func (m *AzureMSIMiddleware) msiEndpoint(rw http.ResponseWriter, req *http.Reque
return trace.Wrap(err)
}

m.Log.Infof("MSI: returning token for identity %v", m.Identity)
m.Log.InfoContext(req.Context(), "MSI: returning token for identity", "identity", m.Identity)

rw.Header().Add("Content-Type", "application/json; charset=utf-8")
rw.Header().Add("Content-Length", fmt.Sprintf("%v", len(respBody)))
Expand Down
Loading

0 comments on commit 41bae13

Please sign in to comment.