diff --git a/pkg/security/auth.go b/pkg/security/auth.go index 79d818346e91..7fbfd497c219 100644 --- a/pkg/security/auth.go +++ b/pkg/security/auth.go @@ -114,15 +114,18 @@ func GetCertificateUserScope( // https://github.com/golang/go/blob/go1.8.1/src/crypto/tls/handshake_server.go#L723:L742 peerCert := tlsState.PeerCertificates[0] for _, uri := range peerCert.URIs { - tenantID, user, err := ParseTenantURISAN(uri.String()) - if err != nil { - return nil, err - } - scope := CertificateUserScope{ - Username: user, - TenantID: tenantID, + uriString := uri.String() + if URISANHasCRDBPrefix(uriString) { + tenantID, user, err := ParseTenantURISAN(uriString) + if err != nil { + return nil, err + } + scope := CertificateUserScope{ + Username: user, + TenantID: tenantID, + } + userScopes = append(userScopes, scope) } - userScopes = append(userScopes, scope) } if len(userScopes) == 0 { users := getCertificatePrincipals(peerCert) diff --git a/pkg/security/auth_test.go b/pkg/security/auth_test.go index d1292879a30d..9dd11dd95889 100644 --- a/pkg/security/auth_test.go +++ b/pkg/security/auth_test.go @@ -85,58 +85,98 @@ func makeFakeTLSState(t *testing.T, spec string) *tls.ConnectionState { func TestGetCertificateUserScope(t *testing.T) { defer leaktest.AfterTest(t)() - // Nil TLS state. - if _, err := security.GetCertificateUserScope(nil); err == nil { - t.Error("unexpected success") - } + t.Run("nil TLS state", func(t *testing.T) { + if _, err := security.GetCertificateUserScope(nil); err == nil { + t.Error("unexpected success") + } + }) - // No certificates. - if _, err := security.GetCertificateUserScope(makeFakeTLSState(t, "")); err == nil { - t.Error("unexpected success") - } + t.Run("no certificates", func(t *testing.T) { + if _, err := security.GetCertificateUserScope(makeFakeTLSState(t, "")); err == nil { + t.Error("unexpected success") + } + }) - // Good request: single certificate. - if userScopes, err := security.GetCertificateUserScope(makeFakeTLSState(t, "foo")); err != nil { - t.Error(err) - } else { - require.Equal(t, 1, len(userScopes)) - require.Equal(t, "foo", userScopes[0].Username) - require.True(t, userScopes[0].Global) - } + t.Run("good request: single certificate", func(t *testing.T) { + if userScopes, err := security.GetCertificateUserScope(makeFakeTLSState(t, "foo")); err != nil { + t.Error(err) + } else { + require.Equal(t, 1, len(userScopes)) + require.Equal(t, "foo", userScopes[0].Username) + require.True(t, userScopes[0].Global) + } + }) - // Request with multiple certs, but only one chain (eg: origin certs are client and CA). - if userScopes, err := security.GetCertificateUserScope(makeFakeTLSState(t, "foo;CA")); err != nil { - t.Error(err) - } else { - require.Equal(t, 1, len(userScopes)) - require.Equal(t, "foo", userScopes[0].Username) - require.True(t, userScopes[0].Global) - } + t.Run("request with multiple certs, but only one chain (eg: origin certs are client and CA)", func(t *testing.T) { + if userScopes, err := security.GetCertificateUserScope(makeFakeTLSState(t, "foo;CA")); err != nil { + t.Error(err) + } else { + require.Equal(t, 1, len(userScopes)) + require.Equal(t, "foo", userScopes[0].Username) + require.True(t, userScopes[0].Global) + } + }) - // Always use the first certificate. - if userScopes, err := security.GetCertificateUserScope(makeFakeTLSState(t, "foo;bar")); err != nil { - t.Error(err) - } else { - require.Equal(t, 1, len(userScopes)) - require.Equal(t, "foo", userScopes[0].Username) - require.True(t, userScopes[0].Global) - } + t.Run("always use the first certificate", func(t *testing.T) { + if userScopes, err := security.GetCertificateUserScope(makeFakeTLSState(t, "foo;bar")); err != nil { + t.Error(err) + } else { + require.Equal(t, 1, len(userScopes)) + require.Equal(t, "foo", userScopes[0].Username) + require.True(t, userScopes[0].Global) + } + }) - // Extract all of the principals from the first certificate. - if userScopes, err := security.GetCertificateUserScope(makeFakeTLSState(t, "foo,dns:bar,dns:blah;CA")); err != nil { - t.Error(err) - } else { - require.Equal(t, 3, len(userScopes)) - require.True(t, userScopes[0].Global) - } - if userScopes, err := security.GetCertificateUserScope(makeFakeTLSState(t, "foo,uri:crdb://tenant/123/user/foo;CA")); err != nil { - t.Error(err) - } else { - require.Equal(t, 1, len(userScopes)) - require.Equal(t, "foo", userScopes[0].Username) - require.Equal(t, roachpb.MakeTenantID(123), userScopes[0].TenantID) - require.False(t, userScopes[0].Global) - } + t.Run("extract all of the principals from the first certificate", func(t *testing.T) { + if userScopes, err := security.GetCertificateUserScope(makeFakeTLSState(t, "foo,dns:bar,dns:blah;CA")); err != nil { + t.Error(err) + } else { + require.Equal(t, 3, len(userScopes)) + require.True(t, userScopes[0].Global) + } + }) + + t.Run("extracts username, tenantID from tenant URI SAN", func(t *testing.T) { + if userScopes, err := security.GetCertificateUserScope( + makeFakeTLSState(t, "foo,uri:crdb://tenant/123/user/foo;CA")); err != nil { + t.Error(err) + } else { + require.Equal(t, 1, len(userScopes)) + require.Equal(t, "foo", userScopes[0].Username) + require.Equal(t, roachpb.MakeTenantID(123), userScopes[0].TenantID) + require.False(t, userScopes[0].Global) + } + }) + + t.Run("extracts tenant URI SAN even when multiple URIs, where one URI is not of CRBD format", func(t *testing.T) { + if userScopes, err := security.GetCertificateUserScope( + makeFakeTLSState(t, "foo,uri:mycompany:sv:rootclient:dev:usw1,uri:crdb://tenant/123/user/foo;CA")); err != nil { + t.Error(err) + } else { + require.Equal(t, 1, len(userScopes)) + require.Equal(t, "foo", userScopes[0].Username) + require.Equal(t, roachpb.MakeTenantID(123), userScopes[0].TenantID) + require.False(t, userScopes[0].Global) + } + }) + + t.Run("errors when tenant URI SAN is not of expected format, even if other URI SAN is provided", func(t *testing.T) { + userScopes, err := security.GetCertificateUserScope( + makeFakeTLSState(t, "foo,uri:mycompany:sv:rootclient:dev:usw1,uri:crdb://tenant/bad/format/123;CA")) + require.Nil(t, userScopes) + require.ErrorContains(t, err, "invalid tenant URI SAN") + }) + + t.Run("falls back to global client cert when crdb URI SAN scheme is not followed", func(t *testing.T) { + if userScopes, err := security.GetCertificateUserScope( + makeFakeTLSState(t, "sanuri,uri:mycompany:sv:rootclient:dev:usw1;CA")); err != nil { + t.Error(err) + } else { + require.Equal(t, 1, len(userScopes)) + require.Equal(t, "sanuri", userScopes[0].Username) + require.True(t, userScopes[0].Global) + } + }) } func TestSetCertPrincipalMap(t *testing.T) { diff --git a/pkg/security/x509.go b/pkg/security/x509.go index 7e7dbc65f2d1..413a4a040a66 100644 --- a/pkg/security/x509.go +++ b/pkg/security/x509.go @@ -38,7 +38,8 @@ const ( validFrom = -time.Hour * 24 maxPathLength = 1 caCommonName = "Cockroach CA" - tenantURISANFormatString = "crdb://tenant/%d/user/%s" + tenantURISANPrefixString = "crdb://" + tenantURISANFormatString = tenantURISANPrefixString + "tenant/%d/user/%s" // TenantsOU is the OrganizationalUnit that determines a client certificate should be treated as a tenant client // certificate (as opposed to a KV node client certificate). @@ -333,6 +334,11 @@ func MakeTenantURISANs( return urls, nil } +// URISANHasCRDBPrefix indicates whether a URI string has the tenant URI SAN prefix. +func URISANHasCRDBPrefix(rawlURI string) bool { + return strings.HasPrefix(rawlURI, tenantURISANPrefixString) +} + // ParseTenantURISAN extracts the user and tenant ID contained within a tenant URI SAN. func ParseTenantURISAN(rawURL string) (roachpb.TenantID, string, error) { r := strings.NewReader(rawURL)