Skip to content

Commit

Permalink
ccl/oidcccl: support principal matching on list claims
Browse files Browse the repository at this point in the history
Previously, matching on ID token claims was not possible if the claim key
specified had a corresponding value that was a list, not a
string. With this change, matching can now occur on claims that are list valued
in order to add login capabilities to DB Console. It is important to note that
this change does NOT offer the user the ability to choose between possible
matches; it simply selects the first match to log the user in.

This change also adds more verbose logging about ID token details.

Epic: none
Fixes: cockroachdb#97301, cockroachdb#97468

Release note (general change): Increasing the logging verbosity
is more helpful with troubleshooting DB Console SSO issues.
  • Loading branch information
cameronnunez committed Mar 14, 2023
1 parent 0205bfe commit 1d20cae
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 16 deletions.
1 change: 1 addition & 0 deletions pkg/ccl/oidcccl/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ go_library(
name = "oidcccl",
srcs = [
"authentication_oidc.go",
"claim_match.go",
"settings.go",
"state.go",
],
Expand Down
38 changes: 22 additions & 16 deletions pkg/ccl/oidcccl/authentication_oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,10 @@ func reloadConfigLocked(

provider, err := oidc.NewProvider(ctx, server.conf.providerURL)
if err != nil {
log.Warningf(ctx, "unable to initialize OIDC provider, disabling OIDC: %v", err)
log.Warningf(ctx, "unable to initialize OIDC server, disabling OIDC: %v", err)
if log.V(1) {
log.Infof(ctx, "check provider URL OIDC cluster setting: "+OIDCProviderURLSettingName)
}
return
}

Expand All @@ -208,7 +211,10 @@ func reloadConfigLocked(

redirectURL, err := getRegionSpecificRedirectURL(locality, server.conf.redirectURLConf)
if err != nil {
log.Warningf(ctx, "unable to initialize OIDC provider, disabling OIDC: %v", err)
log.Warningf(ctx, "unable to initialize OIDC server, disabling OIDC: %v", err)
if log.V(1) {
log.Infof(ctx, "check redirect URL OIDC cluster setting: "+OIDCRedirectURLSettingName)
}
return
}

Expand Down Expand Up @@ -312,16 +318,16 @@ var ConfigureOIDC = func(
return
}

oauth2Token, err := oidcAuthentication.oauth2Config.Exchange(ctx, r.URL.Query().Get(codeKey))
credentials, err := oidcAuthentication.oauth2Config.Exchange(ctx, r.URL.Query().Get(codeKey))
if err != nil {
log.Errorf(ctx, "OIDC: failed to exchange code for token: %v", err)
http.Error(w, genericCallbackHTTPError, http.StatusInternalServerError)
return
}

rawIDToken, ok := oauth2Token.Extra(idTokenKey).(string)
rawIDToken, ok := credentials.Extra(idTokenKey).(string)
if !ok {
log.Error(ctx, "OIDC: failed to extract ID token from OAuth2 token")
log.Error(ctx, "OIDC: failed to extract ID token from the token credentials")
http.Error(w, genericCallbackHTTPError, http.StatusInternalServerError)
return
}
Expand All @@ -340,23 +346,23 @@ var ConfigureOIDC = func(
return
}

var principal string
claim := claims[oidcAuthentication.conf.claimJSONKey]
if err := json.Unmarshal(claim, &principal); err != nil {
log.Errorf(ctx, "OIDC: failed to complete authentication: failed to extract claim key %s: %v", oidcAuthentication.conf.claimJSONKey, err)
http.Error(w, genericCallbackHTTPError, http.StatusInternalServerError)
return
if log.V(1) {
log.Infof(
ctx,
"attempting to extract SQL username from the payload using the claim key %s and regex %s",
oidcAuthentication.conf.claimJSONKey,
oidcAuthentication.conf.principalRegex,
)
}

match := oidcAuthentication.conf.principalRegex.FindStringSubmatch(principal)
numGroups := len(match)
if numGroups != 2 {
log.Errorf(ctx, "OIDC: failed to complete authentication: expected one group in regexp, got %d", numGroups)
username, err := extractUsernameFromClaims(
ctx, claims, oidcAuthentication.conf.claimJSONKey, oidcAuthentication.conf.principalRegex,
)
if err != nil {
http.Error(w, genericCallbackHTTPError, http.StatusInternalServerError)
return
}

username := match[1]
cookie, err := userLoginFromSSO(ctx, username)
if err != nil {
log.Errorf(ctx, "OIDC: failed to complete authentication: unable to create session for %s: %v", username, err)
Expand Down
63 changes: 63 additions & 0 deletions pkg/ccl/oidcccl/authentication_oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@ import (
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"regexp"
"strings"
"testing"

Expand Down Expand Up @@ -289,6 +291,67 @@ func TestOIDCStateEncodeDecode(t *testing.T) {
}
}

func TestOIDCClaimMatch(t *testing.T) {
ctx := context.Background()

for _, tc := range []struct {
testName string
claimKey string
principalRegex string
claims map[string]json.RawMessage
wantError bool
}{
{
testName: "string valued claim",
claimKey: "email",
principalRegex: "^([^@]+)@[^@]+$",
claims: map[string]json.RawMessage{
"email": json.RawMessage(`"[email protected]"`),
},
},
{
testName: "string valued claim with no match",
claimKey: "email",
principalRegex: "^([^@]+)@[^@]+$",
claims: map[string]json.RawMessage{
"email": json.RawMessage(`"bademail"`),
},
wantError: true,
},
{
testName: "list valued claim",
claimKey: "groups",
principalRegex: "^([^@]+)@[^@]+$",
claims: map[string]json.RawMessage{
"groups": json.RawMessage(
`["badgroupname", "[email protected]", "anotherbadgroupname"]`,
),
},
},
{
testName: "list valued claim with no matches",
claimKey: "groups",
principalRegex: "^([^@]+)@[^@]+$",
claims: map[string]json.RawMessage{
"groups": json.RawMessage(`["badgroupname", "anotherbadgroupname"]`),
},
wantError: true,
},
} {
t.Run(tc.testName, func(t *testing.T) {
sqlUsername, err := extractUsernameFromClaims(
ctx, tc.claims, tc.claimKey, regexp.MustCompile(tc.principalRegex),
)
if !tc.wantError {
require.NoError(t, err)
require.Equal(t, "myfakeemail", sqlUsername)
} else {
require.ErrorContains(t, err, "expected one group in regexp")
}
})
}
}

func Test_getRegionSpecificRedirectURL(t *testing.T) {
type args struct {
locality roachpb.Locality
Expand Down
87 changes: 87 additions & 0 deletions pkg/ccl/oidcccl/claim_match.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package oidcccl

import (
"context"
"encoding/json"
"regexp"

"github.com/cockroachdb/cockroach/pkg/util/log"
"github.com/cockroachdb/errors"
)

// extractUsernameFromClaims uses a regex to strip out elements of the value
// corresponding to the token claim claimKey.
func extractUsernameFromClaims(
ctx context.Context,
claims map[string]json.RawMessage,
claimKey string,
principalRE *regexp.Regexp,
) (string, error) {
targetClaim, ok := claims[claimKey]
if !ok {
log.Errorf(
ctx, "OIDC: failed to complete authentication: invalid JSON claim key: %s", claimKey,
)
}

var principal string
if err := json.Unmarshal(targetClaim, &principal); err != nil {
// Try parsing assuming the claim value is a list and not a string.
var principals []string
if err = json.Unmarshal(targetClaim, &principals); err != nil {
log.Errorf(ctx,
"OIDC: failed to complete authentication: failed to parse value for the claim %s: %v",
claimKey, err,
)
return "", err
}
return matchOnListClaim(ctx, principals, principalRE)
}

match := principalRE.FindStringSubmatch(principal)
numGroups := len(match)
if numGroups != 2 {
err := errors.Newf("expected one group in regexp, got %d", numGroups)
log.Errorf(ctx, "OIDC: failed to complete authentication: %v", err)
if log.V(1) {
log.Infof(ctx,
"check OIDC cluster settings: %s, %s",
OIDCPrincipalRegexSettingName, OIDCClaimJSONKeySettingName,
)
}
return "", err
}

return match[1], nil
}

func matchOnListClaim(
ctx context.Context, principals []string, principalRE *regexp.Regexp,
) (string, error) {
// This is the case where the claim key specified is the "groups" claim.
// The first matching principal is selected as the SQL username.
if log.V(1) {
log.Infof(ctx,
"multiple principals in the claim found; selecting first matching principal...",
)
}

var match []string
for _, principal := range principals {
match = principalRE.FindStringSubmatch(principal)
if len(match) == 2 {
return match[1], nil
}
}

// Error when there is not a match.
err := errors.Newf("expected one group in regexp")
log.Errorf(ctx, "OIDC: failed to complete authentication: %v", err)
if log.V(1) {
log.Infof(ctx,
"check OIDC cluster settings: %s, %s",
OIDCPrincipalRegexSettingName, OIDCClaimJSONKeySettingName,
)
}
return "", err
}

0 comments on commit 1d20cae

Please sign in to comment.