Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(auth): pass logger from auth layer to metadata package #11288

Merged
merged 2 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions auth/credentials/compute.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,11 @@ var (

// computeTokenProvider creates a [cloud.google.com/go/auth.TokenProvider] that
// uses the metadata service to retrieve tokens.
func computeTokenProvider(opts *DetectOptions) auth.TokenProvider {
return auth.NewCachedTokenProvider(computeProvider{scopes: opts.Scopes}, &auth.CachedTokenProviderOptions{
func computeTokenProvider(opts *DetectOptions, client *metadata.Client) auth.TokenProvider {
return auth.NewCachedTokenProvider(&computeProvider{
scopes: opts.Scopes,
client: client,
}, &auth.CachedTokenProviderOptions{
ExpireEarly: opts.EarlyTokenRefresh,
DisableAsyncRefresh: opts.DisableAsyncRefresh,
})
Expand All @@ -47,6 +50,7 @@ func computeTokenProvider(opts *DetectOptions) auth.TokenProvider {
// computeProvider fetches tokens from the google cloud metadata service.
type computeProvider struct {
scopes []string
client *metadata.Client
}

type metadataTokenResp struct {
Expand All @@ -55,7 +59,7 @@ type metadataTokenResp struct {
TokenType string `json:"token_type"`
}

func (cs computeProvider) Token(ctx context.Context) (*auth.Token, error) {
func (cs *computeProvider) Token(ctx context.Context) (*auth.Token, error) {
tokenURI, err := url.Parse(computeTokenURI)
if err != nil {
return nil, err
Expand All @@ -65,8 +69,7 @@ func (cs computeProvider) Token(ctx context.Context) (*auth.Token, error) {
v.Set("scopes", strings.Join(cs.scopes, ","))
tokenURI.RawQuery = v.Encode()
}
// TODO(codyoss): create a metadata client and plumb through logger
tokenJSON, err := metadata.GetWithContext(ctx, tokenURI.String())
tokenJSON, err := cs.client.GetWithContext(ctx, tokenURI.String())
if err != nil {
return nil, fmt.Errorf("credentials: cannot fetch token: %w", err)
}
Expand Down
6 changes: 5 additions & 1 deletion auth/credentials/compute_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (
"net/http/httptest"
"strings"
"testing"

"cloud.google.com/go/compute/metadata"
)

const computeMetadataEnvVar = "GCE_METADATA_HOST"
Expand All @@ -42,7 +44,9 @@ func TestComputeTokenProvider(t *testing.T) {
Scopes: []string{
scope,
},
})
},
metadata.NewClient(nil),
)
tok, err := tp.Token(context.Background())
if err != nil {
t.Fatal(err)
Expand Down
11 changes: 8 additions & 3 deletions auth/credentials/detect.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,17 @@ func DetectDefault(opts *DetectOptions) (*auth.Credentials, error) {
}

if OnGCE() {
metadataClient := metadata.NewWithOptions(&metadata.Options{
Logger: opts.logger(),
})
return auth.NewCredentials(&auth.CredentialsOptions{
TokenProvider: computeTokenProvider(opts),
TokenProvider: computeTokenProvider(opts, metadataClient),
ProjectIDProvider: auth.CredentialsPropertyFunc(func(ctx context.Context) (string, error) {
return metadata.ProjectIDWithContext(ctx)
return metadataClient.ProjectIDWithContext(ctx)
}),
UniverseDomainProvider: &internal.ComputeUniverseDomainProvider{},
UniverseDomainProvider: &internal.ComputeUniverseDomainProvider{
MetadataClient: metadataClient,
},
}), nil
}

Expand Down
20 changes: 12 additions & 8 deletions auth/credentials/idtoken/compute.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"cloud.google.com/go/auth"
"cloud.google.com/go/auth/internal"
"cloud.google.com/go/compute/metadata"
"github.com/googleapis/gax-go/v2/internallog"
)

const identitySuffix = "instance/service-accounts/default/identity"
Expand All @@ -34,31 +35,34 @@ func computeCredentials(opts *Options) (*auth.Credentials, error) {
if opts.CustomClaims != nil {
return nil, fmt.Errorf("idtoken: Options.CustomClaims can't be used with the metadata service, please provide a service account if you would like to use this feature")
}
tp := computeIDTokenProvider{
metadataClient := metadata.NewWithOptions(&metadata.Options{
Logger: internallog.New(opts.Logger),
})
tp := &computeIDTokenProvider{
audience: opts.Audience,
format: opts.ComputeTokenFormat,
// TODO(codyoss): connect logger here after metadata options are
// available.
client: *metadata.NewClient(opts.client()),
client: metadataClient,
}
return auth.NewCredentials(&auth.CredentialsOptions{
TokenProvider: auth.NewCachedTokenProvider(tp, &auth.CachedTokenProviderOptions{
ExpireEarly: 5 * time.Minute,
}),
ProjectIDProvider: auth.CredentialsPropertyFunc(func(ctx context.Context) (string, error) {
return metadata.ProjectIDWithContext(ctx)
return metadataClient.ProjectIDWithContext(ctx)
}),
UniverseDomainProvider: &internal.ComputeUniverseDomainProvider{},
UniverseDomainProvider: &internal.ComputeUniverseDomainProvider{
MetadataClient: metadataClient,
},
}), nil
}

type computeIDTokenProvider struct {
audience string
format ComputeTokenFormat
client metadata.Client
client *metadata.Client
}

func (c computeIDTokenProvider) Token(ctx context.Context) (*auth.Token, error) {
func (c *computeIDTokenProvider) Token(ctx context.Context) (*auth.Token, error) {
v := url.Values{}
v.Set("audience", c.audience)
if c.format != ComputeTokenFormatStandard {
Expand Down
2 changes: 1 addition & 1 deletion auth/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module cloud.google.com/go/auth
go 1.21

require (
cloud.google.com/go/compute/metadata v0.5.2
cloud.google.com/go/compute/metadata v0.6.0
github.com/google/go-cmp v0.6.0
github.com/google/s2a-go v0.1.8
github.com/googleapis/enterprise-certificate-proxy v0.3.4
Expand Down
4 changes: 2 additions & 2 deletions auth/go.sum
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
cloud.google.com/go/compute/metadata v0.5.2 h1:UxK4uu/Tn+I3p2dYWTfiX4wva7aYlKixAHn3fyqngqo=
cloud.google.com/go/compute/metadata v0.5.2/go.mod h1:C66sj2AluDcIqakBq/M8lw8/ybHgOZqin2obFxa/E5k=
cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4j01OwKxG9I=
cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
Expand Down
9 changes: 9 additions & 0 deletions auth/grpctransport/grpctransport.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"cloud.google.com/go/auth/credentials"
"cloud.google.com/go/auth/internal"
"cloud.google.com/go/auth/internal/transport"
"github.com/googleapis/gax-go/v2/internallog"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
"google.golang.org/grpc"
grpccreds "google.golang.org/grpc/credentials"
Expand Down Expand Up @@ -137,6 +138,10 @@ func (o *Options) client() *http.Client {
return nil
}

func (o *Options) logger() *slog.Logger {
return internallog.New(o.Logger)
}

func (o *Options) validate() error {
if o == nil {
return errors.New("grpctransport: opts required to be non-nil")
Expand Down Expand Up @@ -178,6 +183,9 @@ func (o *Options) resolveDetectOptions() *credentials.DetectOptions {
do.Client = transport.DefaultHTTPClientWithTLS(tlsConfig)
do.TokenURL = credentials.GoogleMTLSTokenURL
}
if do.Logger == nil {
do.Logger = o.logger()
}
return do
}

Expand Down Expand Up @@ -246,6 +254,7 @@ func dial(ctx context.Context, secure bool, opts *Options) (*grpc.ClientConn, er
ClientCertProvider: opts.ClientCertProvider,
Client: opts.client(),
UniverseDomain: opts.UniverseDomain,
Logger: opts.logger(),
}
if io := opts.InternalOptions; io != nil {
tOpts.DefaultEndpointTemplate = io.DefaultEndpointTemplate
Expand Down
3 changes: 2 additions & 1 deletion auth/grpctransport/grpctransport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
echo "cloud.google.com/go/auth/grpctransport/testdata"
"cloud.google.com/go/auth/internal"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
Expand Down Expand Up @@ -258,7 +259,7 @@ func TestOptions_ResolveDetectOptions(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.in.resolveDetectOptions()
if diff := cmp.Diff(tt.want, got); diff != "" {
if diff := cmp.Diff(tt.want, got, cmpopts.IgnoreFields(credentials.DetectOptions{}, "Logger")); diff != "" {
t.Errorf("mismatch (-want +got):\n%s", diff)
}
})
Expand Down
9 changes: 9 additions & 0 deletions auth/httptransport/httptransport.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
detect "cloud.google.com/go/auth/credentials"
"cloud.google.com/go/auth/internal"
"cloud.google.com/go/auth/internal/transport"
"github.com/googleapis/gax-go/v2/internallog"
)

// ClientCertProvider is a function that returns a TLS client certificate to be
Expand Down Expand Up @@ -107,6 +108,10 @@ func (o *Options) client() *http.Client {
return nil
}

func (o *Options) logger() *slog.Logger {
return internallog.New(o.Logger)
}

func (o *Options) resolveDetectOptions() *detect.DetectOptions {
io := o.InternalOptions
// soft-clone these so we are not updating a ref the user holds and may reuse
Expand All @@ -131,6 +136,9 @@ func (o *Options) resolveDetectOptions() *detect.DetectOptions {
do.Client = transport.DefaultHTTPClientWithTLS(tlsConfig)
do.TokenURL = detect.GoogleMTLSTokenURL
}
if do.Logger == nil {
do.Logger = o.logger()
}
return do
}

Expand Down Expand Up @@ -203,6 +211,7 @@ func NewClient(opts *Options) (*http.Client, error) {
ClientCertProvider: opts.ClientCertProvider,
Client: opts.client(),
UniverseDomain: opts.UniverseDomain,
Logger: opts.logger(),
}
if io := opts.InternalOptions; io != nil {
tOpts.DefaultEndpointTemplate = io.DefaultEndpointTemplate
Expand Down
3 changes: 2 additions & 1 deletion auth/httptransport/httptransport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"cloud.google.com/go/auth/credentials"
"cloud.google.com/go/auth/internal"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
)

func TestAddAuthorizationMiddleware(t *testing.T) {
Expand Down Expand Up @@ -296,7 +297,7 @@ func TestOptions_ResolveDetectOptions(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.in.resolveDetectOptions()
if diff := cmp.Diff(tt.want, got); diff != "" {
if diff := cmp.Diff(tt.want, got, cmpopts.IgnoreFields(credentials.DetectOptions{}, "Logger")); diff != "" {
t.Errorf("mismatch (-want +got):\n%s", diff)
}
})
Expand Down
11 changes: 6 additions & 5 deletions auth/internal/internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ func (p StaticProperty) GetProperty(context.Context) (string, error) {
// ComputeUniverseDomainProvider fetches the credentials universe domain from
// the google cloud metadata service.
type ComputeUniverseDomainProvider struct {
MetadataClient *metadata.Client
universeDomainOnce sync.Once
universeDomain string
universeDomainErr error
Expand All @@ -190,7 +191,7 @@ type ComputeUniverseDomainProvider struct {
// metadata service.
func (c *ComputeUniverseDomainProvider) GetProperty(ctx context.Context) (string, error) {
c.universeDomainOnce.Do(func() {
c.universeDomain, c.universeDomainErr = getMetadataUniverseDomain(ctx)
c.universeDomain, c.universeDomainErr = getMetadataUniverseDomain(ctx, c.MetadataClient)
})
if c.universeDomainErr != nil {
return "", c.universeDomainErr
Expand All @@ -199,14 +200,14 @@ func (c *ComputeUniverseDomainProvider) GetProperty(ctx context.Context) (string
}

// httpGetMetadataUniverseDomain is a package var for unit test substitution.
var httpGetMetadataUniverseDomain = func(ctx context.Context) (string, error) {
var httpGetMetadataUniverseDomain = func(ctx context.Context, client *metadata.Client) (string, error) {
ctx, cancel := context.WithTimeout(ctx, 1*time.Second)
defer cancel()
return metadata.GetWithContext(ctx, "universe/universe-domain")
return client.GetWithContext(ctx, "universe/universe-domain")
}

func getMetadataUniverseDomain(ctx context.Context) (string, error) {
universeDomain, err := httpGetMetadataUniverseDomain(ctx)
func getMetadataUniverseDomain(ctx context.Context, client *metadata.Client) (string, error) {
universeDomain, err := httpGetMetadataUniverseDomain(ctx, client)
if err == nil {
return universeDomain, nil
}
Expand Down
8 changes: 4 additions & 4 deletions auth/internal/internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,29 +28,29 @@ func TestComputeUniverseDomainProvider(t *testing.T) {
notDefinedError := metadata.NotDefinedError("universe/universe_domain")
testCases := []struct {
name string
getFunc func(ctx context.Context) (string, error)
getFunc func(context.Context, *metadata.Client) (string, error)
want string
wantErr error
}{
{
name: "test error",
getFunc: func(ctx context.Context) (string, error) {
getFunc: func(context.Context, *metadata.Client) (string, error) {
return "", fatalErr
},
want: "",
wantErr: fatalErr,
},
{
name: "test error 404",
getFunc: func(ctx context.Context) (string, error) {
getFunc: func(context.Context, *metadata.Client) (string, error) {
return "", notDefinedError
},
want: DefaultUniverseDomain,
wantErr: nil,
},
{
name: "test valid",
getFunc: func(ctx context.Context) (string, error) {
getFunc: func(context.Context, *metadata.Client) (string, error) {
return "example.com", nil
},
want: "example.com",
Expand Down
6 changes: 4 additions & 2 deletions auth/internal/transport/cba.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"crypto/x509"
"errors"
"log"
"log/slog"
"net"
"net/http"
"net/url"
Expand Down Expand Up @@ -63,6 +64,7 @@ type Options struct {
UniverseDomain string
EnableDirectPath bool
EnableDirectPathXds bool
Logger *slog.Logger
}

// getUniverseDomain returns the default service domain for a given Cloud
Expand Down Expand Up @@ -263,8 +265,8 @@ func getTransportConfig(opts *Options) (*transportConfig, error) {
return &defaultTransportConfig, nil
}

s2aAddress := GetS2AAddress()
mtlsS2AAddress := GetMTLSS2AAddress()
s2aAddress := GetS2AAddress(opts.Logger)
mtlsS2AAddress := GetMTLSS2AAddress(opts.Logger)
if s2aAddress == "" && mtlsS2AAddress == "" {
return &defaultTransportConfig, nil
}
Expand Down
Loading
Loading