diff --git a/oidc/pkce_verifier.go b/oidc/pkce_verifier.go index 04d3490..bf20341 100644 --- a/oidc/pkce_verifier.go +++ b/oidc/pkce_verifier.go @@ -7,6 +7,7 @@ import ( "crypto/sha256" "encoding/base64" "fmt" + "regexp" "github.com/hashicorp/cap/oidc/internal/base62" ) @@ -51,19 +52,39 @@ type S256Verifier struct { } // min len of 43 chars per https://tools.ietf.org/html/rfc7636#section-4.1 -const verifierLen = 43 +const ( + // min len of 43 chars per https://tools.ietf.org/html/rfc7636#section-4.1 + minVerifierLen = 43 + // max len of 128 chars per https://tools.ietf.org/html/rfc7636#section-4.1 + maxVerifierLen = 128 +) // NewCodeVerifier creates a new CodeVerifier (*S256Verifier). +// Supported options: WithVerifier // // See: https://tools.ietf.org/html/rfc7636#section-4.1 -func NewCodeVerifier() (*S256Verifier, error) { +func NewCodeVerifier(opt ...Option) (*S256Verifier, error) { const op = "NewCodeVerifier" - data, err := base62.Random(verifierLen) - if err != nil { - return nil, fmt.Errorf("%s: unable to create verifier data %w", op, err) + var ( + err error + verifierData string + ) + opts := getPKCEOpts(opt...) + switch { + case opts.withVerifier != "": + verifierData = opts.withVerifier + default: + var err error + verifierData, err = base62.Random(minVerifierLen) + if err != nil { + return nil, fmt.Errorf("%s: unable to create verifier data %w", op, err) + } + } + if err := verifierIsValid(verifierData); err != nil { + return nil, fmt.Errorf("%s: %w", op, err) } v := &S256Verifier{ - verifier: data, // no need to encode it, since bas62.Random uses a limited set of characters. + verifier: verifierData, // no need to encode it, since bas62.Random uses a limited set of characters. method: S256, } if v.challenge, err = CreateCodeChallenge(v); err != nil { @@ -72,6 +93,25 @@ func NewCodeVerifier() (*S256Verifier, error) { return v, nil } +func verifierIsValid(v string) error { + const op = "verifierIsValid" + switch { + case len(v) < minVerifierLen: + return fmt.Errorf("%s: verifier length is less than %d", op, minVerifierLen) + case len(v) > maxVerifierLen: + return fmt.Errorf("%s: verifier length is greater than %d", op, maxVerifierLen) + default: + // check that the verifier is valid based on + // https://datatracker.ietf.org/doc/html/rfc7636#section-4.1 + // Check for valid characters: A-Z, a-z, 0-9, -, _, ., ~ + validChars := regexp.MustCompile(`^[A-Za-z0-9\-\._~]+$`) + if !validChars.MatchString(v) { + return fmt.Errorf("%s: verifier contains invalid characters", op) + } + } + return nil +} + func (v *S256Verifier) Verifier() string { return v.verifier } // Verifier implements the CodeVerifier.Verifier() interface function. func (v *S256Verifier) Challenge() string { return v.challenge } // Challenge implements the CodeVerifier.Challenge() interface function. func (v *S256Verifier) Method() ChallengeMethod { return v.method } // Method implements the CodeVerifier.Method() interface function. @@ -99,3 +139,35 @@ func CreateCodeChallenge(v CodeVerifier) (string, error) { sum := h.Sum(nil) return base64.RawURLEncoding.EncodeToString(sum), nil } + +// pkceOptions is the set of available options. +type pkceOptions struct { + withVerifier string +} + +// pkceDefaults is a handy way to get the defaults at runtime and +// during unit tests. +func pkceDefaults() pkceOptions { + return pkceOptions{} +} + +// getPKCEOpts gets the defaults and applies the opt overrides passed in. +func getPKCEOpts(opt ...Option) pkceOptions { + opts := pkceDefaults() + ApplyOpts(&opts, opt...) + return opts +} + +// WithVerifier provides an optional verifier for the code verifier. When this +// option is provided, NewCodeVerifier will use the provided verifier. Note the +// verifier must use the base62 character set. +// See: https://datatracker.ietf.org/doc/html/rfc7636#section-4.1 +// +// Valid for: NewVerifier +func WithVerifier(verifier string) Option { + return func(o interface{}) { + if o, ok := o.(*pkceOptions); ok { + o.withVerifier = verifier + } + } +} diff --git a/oidc/pkce_verifier_test.go b/oidc/pkce_verifier_test.go index a406a89..e8e2853 100644 --- a/oidc/pkce_verifier_test.go +++ b/oidc/pkce_verifier_test.go @@ -9,6 +9,7 @@ import ( "errors" "testing" + "github.com/hashicorp/cap/oidc/internal/base62" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -18,7 +19,7 @@ func TestNewCodeVerifier(t *testing.T) { assert, require := assert.New(t), require.New(t) got, err := NewCodeVerifier() require.NoError(err) - assert.Equal(verifierLen, len(got.verifier)) + assert.Equal(minVerifierLen, len(got.verifier)) assert.Equal(S256, got.Method()) challenge, err := CreateCodeChallenge(got) @@ -53,3 +54,35 @@ func TestCreateCodeChallenge(t *testing.T) { assert.True(errors.Is(err, ErrUnsupportedChallengeMethod)) }) } + +func Test_WithVerifier(t *testing.T) { + t.Parallel() + assert, require := assert.New(t), require.New(t) + v, err := base62.Random(43) + require.NoError(err) + got, err := NewCodeVerifier(WithVerifier(v)) + require.NoError(err) + assert.Equal(v, got.Verifier()) + + // Test that the verifier is too short + v, err = base62.Random(42) + require.NoError(err) + _, err = NewCodeVerifier(WithVerifier(v)) + require.Error(err) + assert.Contains(err.Error(), "verifier length is less than 43") + + // Test that the verifier is too long + v, err = base62.Random(129) + require.NoError(err) + _, err = NewCodeVerifier(WithVerifier(v)) + require.Error(err) + assert.Contains(err.Error(), "verifier length is greater than 128") + + // Test that the verifier contains invalid characters + v, err = base62.Random(43) + require.NoError(err) + v = v + "!" + _, err = NewCodeVerifier(WithVerifier(v)) + require.Error(err) + assert.Contains(err.Error(), "verifier contains invalid characters") +}