diff --git a/README.md b/README.md index 33f0a2f..a2024ce 100644 --- a/README.md +++ b/README.md @@ -41,8 +41,9 @@ * [Create](#createuserdefinedfunction) * [Replace](#replaceuserdefinedfunction) * [Delete](#deleteuserdefinedfunction) - * [Iterator](#iterator) +* [Iterator](#iterator) * [DocumentIterator](#documentIterator) +* [Authentication with Azure AD](#authenticationwithazuread) ### Get Started @@ -446,6 +447,39 @@ func main() { } ``` +### Authentication with Azure AD + +You can authenticate with Cosmos DB using Azure AD and a service principal, including full RBAC support. To configure Cosmos DB to use Azure AD, take a look at the [Cosmos DB documentation](https://docs.microsoft.com/en-us/azure/cosmos-db/how-to-setup-rbac). + +To use this library with a service principal: + +```go +import ( + "github.com/Azure/go-autorest/autorest/adal" + "github.com/a8m/documentdb" +) + +func main() { + // Azure AD application (service principal) client credentials + tenantId := "tenant-id" + clientId := "client-id" + clientSecret := "client-secret" + + // Azure AD endpoint may be different for sovereign clouds + oauthConfig, err := adal.NewOAuthConfig("https://login.microsoftonline.com/", tenantId) + if err != nil { + log.Fatal(err) + } + spt, err := adal.NewServicePrincipalToken(*oauthConfig, clientId, clientSecret, "https://cosmos.azure.com") // Always "https://cosmos.azure.com" + if err != nil { + log.Fatal(err) + } + + config := documentdb.NewConfigWithServicePrincipal(spt) + client := documentdb.New("connection-url", config) +} +``` + ### Examples * [Go DocumentDB Example](https://github.com/a8m/go-documentdb-example) - A users CRUD application using Martini and DocumentDB diff --git a/client.go b/client.go index 3131aec..1ad69b4 100644 --- a/client.go +++ b/client.go @@ -22,8 +22,8 @@ type Client struct { http.Client } -func (c *Client) apply(r *Request, opts []CallOption) (err error) { - if err = r.DefaultHeaders(c.Config.MasterKey); err != nil { +func (c *Client) apply(r *Request, opts []CallOption) (err error) { + if err = r.DefaultHeaders(c.Config); err != nil { return err } diff --git a/documentdb.go b/documentdb.go index 51b2f6b..140bca6 100644 --- a/documentdb.go +++ b/documentdb.go @@ -12,6 +12,8 @@ import ( "net/http" "reflect" "sync" + + "github.com/Azure/go-autorest/autorest/adal" ) var buffers = &sync.Pool{ @@ -34,6 +36,7 @@ func DefaultIdentificationHydrator(config *Config, doc interface{}) { type Config struct { MasterKey *Key + ServicePrincipal *adal.ServicePrincipalToken Client http.Client IdentificationHydrator IdentificationHydrator IdentificationPropertyName string @@ -47,6 +50,15 @@ func NewConfig(key *Key) *Config { } } +// NewConfigWithServicePrincipal creates a new Config object that uses Azure AD (via a service principal) for authentication +func NewConfigWithServicePrincipal(servicePrincipal *adal.ServicePrincipalToken) *Config { + return &Config{ + ServicePrincipal: servicePrincipal, + IdentificationHydrator: DefaultIdentificationHydrator, + IdentificationPropertyName: "Id", + } +} + // WithClient stores given http client for later use by documentdb client. func (c *Config) WithClient(client http.Client) *Config { c.Client = client diff --git a/go.mod b/go.mod index 4a4d237..5e02418 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,7 @@ module github.com/a8m/documentdb require ( + github.com/Azure/go-autorest/autorest/adal v0.9.14 github.com/davecgh/go-spew v1.1.1 // indirect github.com/json-iterator/go v1.1.5 github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect diff --git a/go.sum b/go.sum index 0e9d3f6..60245ad 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,19 @@ +github.com/Azure/go-autorest v14.2.0+incompatible h1:V5VMDjClD3GiElqLWO7mz2MxNAK/vTfRHdAubSIPRgs= +github.com/Azure/go-autorest v14.2.0+incompatible/go.mod h1:r+4oMnoxhatjLLJ6zxSWATqVooLgysK6ZNox3g/xq24= +github.com/Azure/go-autorest/autorest/adal v0.9.14 h1:G8hexQdV5D4khOXrWG2YuLCFKhWYmWD8bHYaXN5ophk= +github.com/Azure/go-autorest/autorest/adal v0.9.14/go.mod h1:W/MM4U6nLxnIskrw4UwWzlHfGjwUS50aOsc/I3yuU8M= +github.com/Azure/go-autorest/autorest/date v0.3.0 h1:7gUk1U5M/CQbp9WoqinNzJar+8KY+LPI6wiWrP/myHw= +github.com/Azure/go-autorest/autorest/date v0.3.0/go.mod h1:BI0uouVdmngYNUzGWeSYnokU+TrmwEsOqdt8Y6sso74= +github.com/Azure/go-autorest/autorest/mocks v0.4.1 h1:K0laFcLE6VLTOwNgSxaGbUcLPuGXlNkbVvq4cW4nIHk= +github.com/Azure/go-autorest/autorest/mocks v0.4.1/go.mod h1:LTp+uSrOhSkaKrUy935gNZuuIPPVsHlr9DSOxSayd+k= +github.com/Azure/go-autorest/logger v0.2.1 h1:IG7i4p/mDa2Ce4TRyAO8IHnVhAVF3RFU+ZtXWSmf4Tg= +github.com/Azure/go-autorest/logger v0.2.1/go.mod h1:T9E3cAhj2VqvPOtCYAvby9aBXkZmbF5NWuPV8+WeEW8= +github.com/Azure/go-autorest/tracing v0.6.0 h1:TYi4+3m5t6K48TGI9AUdb+IzbnSxvnvUMfuitfgcfuo= +github.com/Azure/go-autorest/tracing v0.6.0/go.mod h1:+vhtPC754Xsa23ID7GlGsrdKBpUA79WCAKPPZVC2DeU= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/form3tech-oss/jwt-go v3.2.2+incompatible h1:TcekIExNqud5crz4xD2pavyTgWiPvpYe4Xau31I0PRk= +github.com/form3tech-oss/jwt-go v3.2.2+incompatible/go.mod h1:pbq4aXjuKjdthFRnoDwaVPLA+WlJuPGy+QneDUgJi2k= github.com/json-iterator/go v1.1.5 h1:gL2yXlmiIo4+t+y32d4WGwOjKGYcGOuyrg46vadswDE= github.com/json-iterator/go v1.1.5/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= @@ -12,3 +26,10 @@ github.com/stretchr/objx v0.1.1 h1:2vfRuCMp5sSVIDSqO8oNnWJq7mPa6KVP3iPIwFBuy8A= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20201002170205-7f63de1d35b0 h1:hb9wdF1z5waM+dSIICn1l0DkLVDT3hqhhQsDNUmHPRE= +golang.org/x/crypto v0.0.0-20201002170205-7f63de1d35b0/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/request.go b/request.go index 322864e..a9fd971 100644 --- a/request.go +++ b/request.go @@ -60,32 +60,42 @@ func ResourceRequest(link string, req *http.Request) *Request { // Add 3 default headers to *Request // "x-ms-date", "x-ms-version", "authorization" -func (req *Request) DefaultHeaders(mKey *Key) (err error) { +func (req *Request) DefaultHeaders(config *Config) (err error) { req.Header.Add(HeaderXDate, formatDate(time.Now())) req.Header.Add(HeaderVersion, SupportedVersion) - b := buffers.Get().(*bytes.Buffer) - b.Reset() - b.WriteString(req.Method) - b.WriteRune('\n') - b.WriteString(req.rType) - b.WriteRune('\n') - b.WriteString(req.rId) - b.WriteRune('\n') - b.WriteString(req.Header.Get(HeaderXDate)) - b.WriteRune('\n') - b.WriteString(req.Header.Get("Date")) - b.WriteRune('\n') - - sign, err := authorize(bytes.ToLower(b.Bytes()), mKey) - if err != nil { - return err + // Authentication via master key + if config.MasterKey != nil && config.MasterKey.Key != "" { + b := buffers.Get().(*bytes.Buffer) + b.Reset() + b.WriteString(req.Method) + b.WriteRune('\n') + b.WriteString(req.rType) + b.WriteRune('\n') + b.WriteString(req.rId) + b.WriteRune('\n') + b.WriteString(req.Header.Get(HeaderXDate)) + b.WriteRune('\n') + b.WriteString(req.Header.Get("Date")) + b.WriteRune('\n') + + sign, err := authorize(bytes.ToLower(b.Bytes()), config.MasterKey) + if err != nil { + return err + } + + buffers.Put(b) + + req.Header.Add(HeaderAuth, url.QueryEscape("type=master&ver=1.0&sig="+sign)) + } else if config.ServicePrincipal != nil { + err := config.ServicePrincipal.EnsureFresh() + if err != nil { + return err + } + token := config.ServicePrincipal.OAuthToken() + req.Header.Add(HeaderAuth, url.QueryEscape("type=aad&ver=1.0&sig="+token)) } - buffers.Put(b) - - req.Header.Add(HeaderAuth, url.QueryEscape("type=master&ver=1.0&sig="+sign)) - return }