Skip to content

Commit

Permalink
Azure AD authentication for the Postgres backend (#15757)
Browse files Browse the repository at this point in the history
* Add Username to sqlbk and don't leak connConfigs

* Azure AD authentication for sqlbk/Postgres

* Add a Postgres Config test

* Cache Azure tokens, document azureBeforeConnect

* Move the config test to sqlbk

* go mod tidy

* go get azcore azidentity
  • Loading branch information
espadolini authored Sep 15, 2022
1 parent 7695f23 commit 33c6d82
Show file tree
Hide file tree
Showing 8 changed files with 309 additions and 52 deletions.
12 changes: 6 additions & 6 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ require (
cloud.google.com/go/firestore v1.6.1
cloud.google.com/go/iam v0.3.0
cloud.google.com/go/storage v1.23.0
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.0.0
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.0.0
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.1.3
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.1.0
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/mysql/armmysql v1.0.0
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/postgresql/armpostgresql v1.0.0
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/subscription/armsubscription v1.0.0
Expand Down Expand Up @@ -42,6 +42,7 @@ require (
github.com/ghodss/yaml v1.0.0
github.com/gizak/termui/v3 v3.1.0
github.com/go-ldap/ldap/v3 v3.4.1
github.com/go-logr/logr v1.2.3
github.com/go-mysql-org/go-mysql v1.5.0
github.com/go-redis/redis/v8 v8.11.4
github.com/gobuffalo/flect v0.2.5
Expand Down Expand Up @@ -139,7 +140,9 @@ require (
k8s.io/apiserver v0.24.2
k8s.io/cli-runtime v0.24.0
k8s.io/client-go v0.24.2
k8s.io/klog/v2 v2.60.1
k8s.io/kubectl v0.24.0
k8s.io/utils v0.0.0-20220210201930-3a6ce19ff2f9
sigs.k8s.io/controller-runtime v0.12.3
sigs.k8s.io/controller-tools v0.9.2
sigs.k8s.io/yaml v1.3.0
Expand All @@ -158,7 +161,7 @@ require (
github.com/Azure/go-autorest/logger v0.2.1 // indirect
github.com/Azure/go-autorest/tracing v0.6.0 // indirect
github.com/Azure/go-ntlmssp v0.0.0-20200615164410-66371956d46c // indirect
github.com/AzureAD/microsoft-authentication-library-for-go v0.4.0 // indirect
github.com/AzureAD/microsoft-authentication-library-for-go v0.5.1 // indirect
github.com/MakeNowJust/heredoc v0.0.0-20170808103936-bb23615498cd // indirect
github.com/PuerkitoBio/purell v1.1.1 // indirect
github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 // indirect
Expand Down Expand Up @@ -199,7 +202,6 @@ require (
github.com/gabriel-vasile/mimetype v1.4.0 // indirect
github.com/go-asn1-ber/asn1-ber v1.5.1 // indirect
github.com/go-errors/errors v1.0.1 // indirect
github.com/go-logr/logr v1.2.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-logr/zapr v1.2.0 // indirect
github.com/go-openapi/jsonpointer v0.19.5 // indirect
Expand Down Expand Up @@ -306,9 +308,7 @@ require (
gopkg.in/inf.v0 v0.9.1 // indirect
gopkg.in/mgo.v2 v2.0.0-20190816093944-a6b53ec6cb22 // indirect
k8s.io/component-base v0.24.2 // indirect
k8s.io/klog/v2 v2.60.1 // indirect
k8s.io/kube-openapi v0.0.0-20220328201542-3ee0da9b0b42 // indirect
k8s.io/utils v0.0.0-20220210201930-3a6ce19ff2f9 // indirect
launchpad.net/gocheck v0.0.0-20140225173054-000000000087 // indirect
sigs.k8s.io/json v0.0.0-20211208200746-9f7c6b3444d2 // indirect
sigs.k8s.io/kustomize/api v0.11.4 // indirect
Expand Down
12 changes: 6 additions & 6 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,11 @@ gioui.org v0.0.0-20210308172011-57750fc8a0a6/go.mod h1:RSH6KIUZ0p2xy5zHDxgAM4zum
github.com/Azure/azure-pipeline-go v0.2.3 h1:7U9HBg1JFK3jHl5qmo4CTZKFTVgMwdFHMVtCdfBE21U=
github.com/Azure/azure-pipeline-go v0.2.3/go.mod h1:x841ezTBIMG6O3lAcl8ATHnsOPVl2bqk7S3ta6S6u4k=
github.com/Azure/azure-sdk-for-go/sdk/azcore v0.19.0/go.mod h1:h6H6c8enJmmocHUbLiiGY6sx7f9i+X3m1CHdd5c6Rdw=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.0.0 h1:sVPhtT2qjO86rTUaWMr4WoES4TkjGnzcioXcnHV9s5k=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.0.0/go.mod h1:uGG2W01BaETf0Ozp+QxxKJdMBNRWPdstHG0Fmdwn1/U=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.1.3 h1:8LoU8N2lIUzkmstvwXvVfniMZlFbesfT2AmA1aqvRr8=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.1.3/go.mod h1:uGG2W01BaETf0Ozp+QxxKJdMBNRWPdstHG0Fmdwn1/U=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.11.0/go.mod h1:HcM1YX14R7CJcghJGOYCgdezslRSVzqwLf/q+4Y2r/0=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.0.0 h1:Yoicul8bnVdQrhDMTHxdEckRGX01XvwXDHUT9zYZ3k0=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.0.0/go.mod h1:+6sju8gk8FRmSajX3Oz4G5Gm7P+mbqE9FVaXXFYTkCM=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.1.0 h1:QkAcEIAKbNL4KoFr4SathZPhDhF4mVwpBMFlYjyAqy8=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.1.0/go.mod h1:bhXu1AjYL+wutSL/kpSq6s7733q2Rb0yuot9Zgfqa/0=
github.com/Azure/azure-sdk-for-go/sdk/internal v0.7.0/go.mod h1:yqy467j36fJxcRV2TzfVZ1pCb5vxm4BtZPUdYWe/Xo8=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.0.0 h1:jp0dGvZ7ZK0mgqnTSClMxa5xuRL7NZgHameVYF6BurY=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.0.0/go.mod h1:eWRD7oawr1Mu1sLCawqVc0CUiF43ia3qQMxLscsKQ9w=
Expand Down Expand Up @@ -103,8 +103,8 @@ github.com/Azure/go-autorest/tracing v0.6.0 h1:TYi4+3m5t6K48TGI9AUdb+IzbnSxvnvUM
github.com/Azure/go-autorest/tracing v0.6.0/go.mod h1:+vhtPC754Xsa23ID7GlGsrdKBpUA79WCAKPPZVC2DeU=
github.com/Azure/go-ntlmssp v0.0.0-20200615164410-66371956d46c h1:/IBSNwUN8+eKzUzbJPqhK839ygXJ82sde8x3ogr6R28=
github.com/Azure/go-ntlmssp v0.0.0-20200615164410-66371956d46c/go.mod h1:chxPXzSsl7ZWRAuOIE23GDNzjWuZquvFlgA8xmpunjU=
github.com/AzureAD/microsoft-authentication-library-for-go v0.4.0 h1:WVsrXCnHlDDX8ls+tootqRE87/hL9S/g4ewig9RsD/c=
github.com/AzureAD/microsoft-authentication-library-for-go v0.4.0/go.mod h1:Vt9sXTKwMyGcOxSmLDMnGPgqsUg7m8pe215qMLrDXw4=
github.com/AzureAD/microsoft-authentication-library-for-go v0.5.1 h1:BWe8a+f/t+7KY7zH2mqygeUD0t8hNFXe08p1Pb3/jKE=
github.com/AzureAD/microsoft-authentication-library-for-go v0.5.1/go.mod h1:Vt9sXTKwMyGcOxSmLDMnGPgqsUg7m8pe215qMLrDXw4=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
github.com/Clever/go-utils v0.0.0-20180917210021-2dac0ec6f2ac h1:eoofDGlVjiTro1kr97QnRkW4b/MGDSmm4E0ta2rqFos=
Expand Down
91 changes: 91 additions & 0 deletions lib/backend/postgres/azure.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
// Copyright 2022 Gravitational, Inc
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package postgres

import (
"context"
"sync"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/gravitational/trace"
"github.com/jackc/pgx/v4"
"github.com/sirupsen/logrus"
)

// azureBeforeConnect will return a pgx BeforeConnect function suitable for
// Azure AD authentication. The returned function will set the password of the
// pgx.ConnConfig to a token for the relevant scope, fetching it and reusing it
// until expired (a burst of connections right at backend start is expected). If
// a client ID is provided, authentication will only be attempted as the managed
// identity with said ID rather than with all the default credentials.
func azureBeforeConnect(clientID string, log logrus.FieldLogger) (func(ctx context.Context, config *pgx.ConnConfig) error, error) {
var cred azcore.TokenCredential
if clientID != "" {
log.WithField("azure_client_id", clientID).Debug("Using Azure AD authentication with managed identity.")
c, err := azidentity.NewManagedIdentityCredential(&azidentity.ManagedIdentityCredentialOptions{
ID: azidentity.ClientID(clientID),
})
if err != nil {
return nil, trace.Wrap(err)
}
cred = c
} else {
log.Debug("Using Azure AD authentication with default credentials.")
c, err := azidentity.NewDefaultAzureCredential(nil)
if err != nil {
return nil, trace.Wrap(err)
}
cred = c
}

var mu sync.Mutex
var cachedToken azcore.AccessToken

beforeConnect := func(ctx context.Context, config *pgx.ConnConfig) error {
mu.Lock()
token := cachedToken
mu.Unlock()

// to account for clock drift between us, the database, and the IDMS,
// refresh the token 10 minutes before we think it will expire
if token.ExpiresOn.After(time.Now().Add(10 * time.Minute)) {
log.WithField("ttl", time.Until(token.ExpiresOn).String()).Debug("Reusing cached Azure access token.")
config.Password = token.Token
return nil
}

log.Debug("Fetching new Azure access token.")
token, err := cred.GetToken(ctx, policy.TokenRequestOptions{
Scopes: []string{"https://ossrdbms-aad.database.windows.net/.default"},
})
if err != nil {
return trace.Wrap(err)
}

log.WithField("ttl", time.Until(token.ExpiresOn).String()).Debug("Fetched Azure access token.")
config.Password = token.Token

mu.Lock()
cachedToken = token
mu.Unlock()

return nil
}

return beforeConnect, nil
}
8 changes: 3 additions & 5 deletions lib/backend/postgres/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,13 @@ import (
"context"
"time"

"github.com/gravitational/teleport/api/utils"
"github.com/gravitational/teleport/lib/backend"
"github.com/gravitational/teleport/lib/backend/sqlbk"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/sirupsen/logrus"

// Ensure pgx driver is registered.
_ "github.com/jackc/pgx/v4/stdlib"
"github.com/gravitational/teleport/api/utils"
"github.com/gravitational/teleport/lib/backend"
"github.com/gravitational/teleport/lib/backend/sqlbk"
)

const (
Expand Down
9 changes: 5 additions & 4 deletions lib/backend/postgres/backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,14 @@ import (
"testing"
"time"

"github.com/gravitational/teleport/api/utils"
"github.com/gravitational/teleport/lib/backend"
"github.com/gravitational/teleport/lib/backend/sqlbk"
"github.com/gravitational/trace"
"github.com/sirupsen/logrus"
"gopkg.in/yaml.v2"

"github.com/gravitational/teleport/api/utils"
"github.com/gravitational/teleport/lib/backend"
"github.com/gravitational/teleport/lib/backend/sqlbk"

"github.com/jackc/pgx/v4"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -168,7 +169,7 @@ func TestDriverURL(t *testing.T) {
driver.cfg.TLS.ClientCertFile = "certfile"
driver.cfg.TLS.ClientKeyFile = "keyfile"

expect, err := url.Parse("postgres://host:123/database?sslmode=verify-full&sslrootcert=cafile&sslcert=certfile&sslkey=keyfile")
expect, err := url.Parse("postgres://host:123/database?user=&sslmode=verify-full&sslrootcert=cafile&sslcert=certfile&sslkey=keyfile")
require.NoError(t, err)
expectQuery := expect.Query()
expect.RawQuery = ""
Expand Down
34 changes: 26 additions & 8 deletions lib/backend/postgres/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,12 @@ import (
"net/url"
"time"

"github.com/gravitational/teleport/lib/backend/sqlbk"
"github.com/gravitational/trace"
"github.com/jackc/pgconn"
"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/stdlib"

"github.com/gravitational/teleport/lib/backend/sqlbk"
)

// pgDriver implements backend.Driver for a PostgreSQL or CockroachDB database.
Expand Down Expand Up @@ -63,8 +64,17 @@ func (d *pgDriver) open(ctx context.Context, u *url.URL) (sqlbk.DB, error) {
}
connConfig.Logger = d.sqlLogger

// extract the user from the first client certificate in TLSConfig.
if connConfig.TLSConfig != nil {
beforeConnect := func(ctx context.Context, config *pgx.ConnConfig) error { return nil }
if d.cfg.Azure.Username != "" {
beforeConnect, err = azureBeforeConnect(d.cfg.Azure.ClientID, d.cfg.Log)
if err != nil {
return nil, trace.Wrap(err)
}
}

// Unless otherwise specified, extract the user from the first client
// certificate in TLSConfig.
if connConfig.User == "" && connConfig.TLSConfig != nil {
connConfig.User, err = tlsConfigUser(connConfig.TLSConfig)
if err != nil {
return nil, trace.Wrap(err)
Expand All @@ -75,17 +85,20 @@ func (d *pgDriver) open(ctx context.Context, u *url.URL) (sqlbk.DB, error) {
}

// Attempt to create backend database if it does not exist.
err = d.maybeCreateDatabase(ctx, connConfig)
if err != nil {
createConfig := *connConfig
// We have to do the beforeConnect dance manually here because it's not part
// of pgx.ConnConfig, it's only stored in connection pool configs.
if err := beforeConnect(ctx, &createConfig); err != nil {
return nil, trace.Wrap(err)
}

// Open connection/pool for backend database.
db, err := sql.Open("pgx", stdlib.RegisterConnConfig(connConfig))
err = d.maybeCreateDatabase(ctx, &createConfig)
if err != nil {
return nil, trace.Wrap(err)
}

// Open connection/pool for backend database.
db := stdlib.OpenDB(*connConfig, stdlib.OptionBeforeConnect(beforeConnect))

// Configure the connection pool.
db.SetConnMaxIdleTime(d.cfg.ConnMaxIdleTime)
db.SetConnMaxLifetime(d.cfg.ConnMaxLifetime)
Expand Down Expand Up @@ -155,6 +168,11 @@ func (d *pgDriver) url() *url.URL {
}
q := u.Query()
q.Set("sslmode", "verify-full")
user := d.cfg.TLS.Username
if d.cfg.Azure.Username != "" {
user = d.cfg.Azure.Username
}
q.Set("user", user)
q.Set("sslrootcert", d.cfg.TLS.CAFile)
q.Set("sslcert", d.cfg.TLS.ClientCertFile)
q.Set("sslkey", d.cfg.TLS.ClientKeyFile)
Expand Down
Loading

0 comments on commit 33c6d82

Please sign in to comment.