From 09a526c4b48c92c03173116c5aeb3e6654da958b Mon Sep 17 00:00:00 2001 From: Jim Date: Thu, 1 Aug 2024 14:09:14 -0400 Subject: [PATCH] fixup! feat (config): add support for a http.RoundTripper --- oidc/config_test.go | 141 +++++++++++++++++++++++++++++++++++++++++- oidc/provider_test.go | 25 ++++++++ 2 files changed, 165 insertions(+), 1 deletion(-) diff --git a/oidc/config_test.go b/oidc/config_test.go index 4dd3b00..85b576f 100644 --- a/oidc/config_test.go +++ b/oidc/config_test.go @@ -7,6 +7,7 @@ import ( "crypto/x509" "errors" "fmt" + "net/http" "testing" "time" @@ -44,6 +45,8 @@ func TestNewConfig(t *testing.T) { return time.Now().Add(-1 * time.Minute) } + testRt := newTestRoundTripper(t) + type args struct { issuer string clientID string @@ -61,7 +64,7 @@ func TestNewConfig(t *testing.T) { wantErrContains string }{ { - name: "valid-with-all-valid-opts", + name: "valid-with-all-valid-opts-except-with-round-tripper", args: args{ issuer: "http://your_issuer/", clientID: "your_client_id", @@ -103,6 +106,49 @@ func TestNewConfig(t *testing.T) { }, }, }, + { + name: "with-round-tripper", + args: args{ + issuer: "http://your_issuer/", + clientID: "your_client_id", + clientSecret: "your_client_secret", + supported: []Alg{RS512}, + allowedRedirectURLs: []string{"http://your_redirect_url", "http://redirect_url_two", "http://redirect_url_three"}, + opt: []Option{ + WithAudiences("your_aud1", "your_aud2"), + WithScopes("email", "profile"), + WithRoundTripper(testRt), + WithNow(testNow), + WithProviderConfig(&ProviderConfig{ + AuthURL: "https://auth-endpoint", + JWKSURL: "https://jwks-endpoint", + TokenURL: "https://token-endpoint", + UserInfoURL: "https://userinfo-endpoint", + }), + }, + }, + want: &Config{ + Issuer: "http://your_issuer/", + ClientID: "your_client_id", + ClientSecret: "your_client_secret", + SupportedSigningAlgs: []Alg{RS512}, + Audiences: []string{"your_aud1", "your_aud2"}, + Scopes: []string{oidc.ScopeOpenID, "email", "profile"}, + RoundTripper: testRt, + NowFunc: testNow, + AllowedRedirectURLs: []string{ + "http://your_redirect_url", + "http://redirect_url_two", + "http://redirect_url_three", + }, + ProviderConfig: &ProviderConfig{ + AuthURL: "https://auth-endpoint", + JWKSURL: "https://jwks-endpoint", + TokenURL: "https://token-endpoint", + UserInfoURL: "https://userinfo-endpoint", + }, + }, + }, { name: "missing-provider-config-auth-url", args: args{ @@ -282,6 +328,22 @@ func TestNewConfig(t *testing.T) { wantErr: true, wantIsErr: ErrInvalidCACert, }, + { + name: "invalid-both-cert-and-round-tripper", + args: args{ + issuer: "http://your_issuer/", + clientID: "your_client_id", + clientSecret: "your_client_secret", + supported: []Alg{RS512}, + allowedRedirectURLs: []string{"http://your_redirect_url"}, + opt: []Option{ + WithProviderCA(testCaPem), + WithRoundTripper(testRt), + }, + }, + wantErr: true, + wantIsErr: ErrInvalidParameter, + }, { name: "invalid-alg", args: args{ @@ -430,6 +492,7 @@ func TestConfig_Hash(t *testing.T) { require.NoError(t, err) return c } + testRt := newTestRoundTripper(t) tests := []struct { name string c1 *Config @@ -473,6 +536,42 @@ func TestConfig_Hash(t *testing.T) { ), wantEqual: true, }, + { + name: "equal-with-round-tripper", + c1: newCfg( + "https://www.alice.com", + "client-id", "client-secret", + []Alg{RS256}, + []string{"www.alice.com/callback", "www.bob.com/callback"}, + WithScopes("email", "profile"), + WithAudiences("alice.com", "bob.com"), + WithRoundTripper(testRt), + WithNow(time.Now), + WithProviderConfig(&ProviderConfig{ + AuthURL: "https://auth-endpoint", + JWKSURL: "https://jwks-endpoint", + TokenURL: "https://token-endpoint", + UserInfoURL: "https://userinfo-endpoint", + }), + ), + c2: newCfg( + "https://www.alice.com", + "client-id", "client-secret", + []Alg{RS256}, + []string{"www.bob.com/callback", "www.alice.com/callback"}, + WithScopes("profile", "email"), + WithAudiences("bob.com", "alice.com"), + WithRoundTripper(testRt), + WithNow(time.Now), + WithProviderConfig(&ProviderConfig{ + AuthURL: "https://auth-endpoint", + JWKSURL: "https://jwks-endpoint", + TokenURL: "https://token-endpoint", + UserInfoURL: "https://userinfo-endpoint", + }), + ), + wantEqual: true, + }, { name: "diff-issuer", c1: newCfg( @@ -664,6 +763,29 @@ func TestConfig_Hash(t *testing.T) { ), wantEqual: false, }, + { + name: "diff-round-trippers", + c1: newCfg( + "https://www.alice.com", + "client-id", "client-secret", + []Alg{RS256}, + []string{"www.alice.com/callback"}, + WithScopes("email", "profile"), + WithAudiences("alice.com", "bob.com"), + WithRoundTripper(newTestRoundTripper(t)), + WithNow(time.Now), + ), + c2: newCfg( + "https://www.alice.com", + "client-id", "client-secret", + []Alg{RS256}, + []string{"www.alice.com/callback"}, + WithScopes("email", "profile"), + WithAudiences("alice.com", "bob.com"), + WithNow(time.Now), + ), + wantEqual: false, + }, { name: "diff-now-func", c1: newCfg( @@ -855,3 +977,20 @@ func TestConfig_Hash(t *testing.T) { }) } } + +type testRoundTripper struct { + transport http.RoundTripper + called int +} + +func newTestRoundTripper(t *testing.T) *testRoundTripper { + t.Helper() + return &testRoundTripper{ + transport: http.DefaultTransport, + } +} + +func (rt *testRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + rt.called++ + return rt.transport.RoundTrip(req) +} diff --git a/oidc/provider_test.go b/oidc/provider_test.go index b715863..8dca1dd 100644 --- a/oidc/provider_test.go +++ b/oidc/provider_test.go @@ -714,6 +714,31 @@ func TestHTTPClient(t *testing.T) { require.NoError(t, err) assert.Equal(t, c.Transport, p.client.Transport) }) + t.Run("check-transport-with-round-tripper", func(t *testing.T) { + testRt := newTestRoundTripper(t) + p := &Provider{ + config: &Config{ + RoundTripper: testRt, + }, + } + c, err := p.HTTPClient() + require.NoError(t, err) + assert.Equal(t, c.Transport, p.client.Transport) + }) + t.Run("err-both-ca-and-round-trippe", func(t *testing.T) { + _, testCaPem := TestGenerateCA(t, []string{"localhost"}) + + p := &Provider{ + config: &Config{ + ProviderCA: testCaPem, + RoundTripper: newTestRoundTripper(t), + }, + } + _, err := p.HTTPClient() + require.Error(t, err) + assert.ErrorIs(t, err, ErrInvalidParameter) + assert.ErrorContains(t, err, "you cannot specify config for both a ProviderCA and Transport") + }) } func TestProvider_UserInfo(t *testing.T) {