Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add KeyCredential and SASCredential types #21553

Merged
merged 5 commits into from
Sep 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions sdk/azcore/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@

### Features Added

* Added types `KeyCredential` and `SASCredential` to the `azcore` package.
* Includes their respective constructor functions.
* Added types `KeyCredentialPolicy` and `SASCredentialPolicy` to the `azcore/runtime` package.
* Includes their respective constructor functions and options types.

### Breaking Changes

### Bugs Fixed
Expand Down
18 changes: 18 additions & 0 deletions sdk/azcore/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,24 @@ type AccessToken = exported.AccessToken
// TokenCredential represents a credential capable of providing an OAuth token.
type TokenCredential = exported.TokenCredential

// KeyCredential contains an authentication key used to authenticate to an Azure service.
type KeyCredential = exported.KeyCredential

// NewKeyCredential creates a new instance of [KeyCredential] with the specified values.
// - key is the authentication key
func NewKeyCredential(key string) *KeyCredential {
return exported.NewKeyCredential(key)
}

// SASCredential contains a shared access signature used to authenticate to an Azure service.
type SASCredential = exported.SASCredential

// NewSASCredential creates a new instance of [SASCredential] with the specified values.
// - sas is the shared access signature
func NewSASCredential(sas string) *SASCredential {
return exported.NewSASCredential(sas)
}

// holds sentinel values used to send nulls
var nullables map[reflect.Type]interface{} = map[reflect.Type]interface{}{}

Expand Down
8 changes: 8 additions & 0 deletions sdk/azcore/core_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,3 +221,11 @@ func TestClientWithClientName(t *testing.T) {
require.NoError(t, err)
require.EqualValues(t, "az.namespace:Widget.Factory", attrString)
}

func TestNewKeyCredential(t *testing.T) {
require.NotNil(t, NewKeyCredential("foo"))
}

func TestNewSASCredential(t *testing.T) {
require.NotNil(t, NewSASCredential("foo"))
}
63 changes: 63 additions & 0 deletions sdk/azcore/internal/exported/exported.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"fmt"
"io"
"net/http"
"sync/atomic"
"time"
)

Expand Down Expand Up @@ -110,3 +111,65 @@ func DecodeByteArray(s string, v *[]byte, format Base64Encoding) error {
return fmt.Errorf("unrecognized byte array format: %d", format)
}
}

// KeyCredential contains an authentication key used to authenticate to an Azure service.
// Exported as azcore.KeyCredential.
type KeyCredential struct {
cred *keyCredential
}

// NewKeyCredential creates a new instance of [KeyCredential] with the specified values.
// - key is the authentication key
func NewKeyCredential(key string) *KeyCredential {
return &KeyCredential{cred: newKeyCredential(key)}
}

// Update replaces the existing key with the specified value.
func (k *KeyCredential) Update(key string) {
k.cred.Update(key)
}

// SASCredential contains a shared access signature used to authenticate to an Azure service.
// Exported as azcore.SASCredential.
type SASCredential struct {
cred *keyCredential
}

// NewSASCredential creates a new instance of [SASCredential] with the specified values.
// - sas is the shared access signature
func NewSASCredential(sas string) *SASCredential {
return &SASCredential{cred: newKeyCredential(sas)}
}

// Update replaces the existing shared access signature with the specified value.
func (k *SASCredential) Update(sas string) {
k.cred.Update(sas)
}

// KeyCredentialGet returns the key for cred.
func KeyCredentialGet(cred *KeyCredential) string {
return cred.cred.Get()
}

// SASCredentialGet returns the shared access sig for cred.
func SASCredentialGet(cred *SASCredential) string {
return cred.cred.Get()
}

type keyCredential struct {
key atomic.Value // string
}

func newKeyCredential(key string) *keyCredential {
keyCred := keyCredential{}
keyCred.key.Store(key)
return &keyCred
}

func (k *keyCredential) Get() string {
return k.key.Load().(string)
}

func (k *keyCredential) Update(key string) {
k.key.Store(key)
}
41 changes: 41 additions & 0 deletions sdk/azcore/internal/exported/exported_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@
package exported

import (
"fmt"
"net/http"
"strings"
"testing"

"github.com/stretchr/testify/require"
)

func TestNopCloser(t *testing.T) {
Expand All @@ -33,3 +36,41 @@ func TestHasStatusCode(t *testing.T) {
t.Fatal("unexpected failure")
}
}

func TestDecodeByteArray(t *testing.T) {
out := []byte{}
require.NoError(t, DecodeByteArray("", &out, Base64StdFormat))
require.Empty(t, out)
const (
stdEncoding = "VGVzdERlY29kZUJ5dGVBcnJheQ=="
urlEncoding = "VGVzdERlY29kZUJ5dGVBcnJheQ"
decoded = "TestDecodeByteArray"
)
require.NoError(t, DecodeByteArray(stdEncoding, &out, Base64StdFormat))
require.EqualValues(t, decoded, string(out))
require.NoError(t, DecodeByteArray(urlEncoding, &out, Base64URLFormat))
require.EqualValues(t, decoded, string(out))
require.NoError(t, DecodeByteArray(fmt.Sprintf("\"%s\"", stdEncoding), &out, Base64StdFormat))
require.EqualValues(t, decoded, string(out))
require.Error(t, DecodeByteArray(stdEncoding, &out, 123))
}

func TestNewKeyCredential(t *testing.T) {
const val1 = "foo"
cred := NewKeyCredential(val1)
require.NotNil(t, cred)
require.EqualValues(t, val1, KeyCredentialGet(cred))
const val2 = "bar"
cred.Update(val2)
require.EqualValues(t, val2, KeyCredentialGet(cred))
}

func TestNewSASCredential(t *testing.T) {
const val1 = "foo"
cred := NewSASCredential(val1)
require.NotNil(t, cred)
require.EqualValues(t, val1, SASCredentialGet(cred))
const val2 = "bar"
cred.Update(val2)
require.EqualValues(t, val2, SASCredentialGet(cred))
}
49 changes: 49 additions & 0 deletions sdk/azcore/runtime/policy_key_credential.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package runtime

import (
"net/http"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
)

// KeyCredentialPolicy authorizes requests with a [azcore.KeyCredential].
type KeyCredentialPolicy struct {
cred *exported.KeyCredential
header string
prefix string
}

// KeyCredentialPolicyOptions contains the optional values configuring [KeyCredentialPolicy].
type KeyCredentialPolicyOptions struct {
// Prefix is used if the key requires a prefix before it's inserted into the HTTP request.
Prefix string
}

// NewKeyCredentialPolicy creates a new instance of [KeyCredentialPolicy].
// - cred is the [azcore.KeyCredential] used to authenticate with the service
// - header is the name of the HTTP request header in which the key is placed
// - options contains optional configuration, pass nil to accept the default values
func NewKeyCredentialPolicy(cred *exported.KeyCredential, header string, options *KeyCredentialPolicyOptions) *KeyCredentialPolicy {
if options == nil {
options = &KeyCredentialPolicyOptions{}
}
return &KeyCredentialPolicy{
cred: cred,
header: header,
prefix: options.Prefix,
}
}

// Do implementes the Do method on the [policy.Polilcy] interface.
func (k *KeyCredentialPolicy) Do(req *policy.Request) (*http.Response, error) {
val := exported.KeyCredentialGet(k.cred)
if k.prefix != "" {
val = k.prefix + val
}
req.Raw().Header.Add(k.header, val)
return req.Next()
}
50 changes: 50 additions & 0 deletions sdk/azcore/runtime/policy_key_credential_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package runtime

import (
"context"
"net/http"
"testing"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared"
"github.com/stretchr/testify/require"
)

func TestKeyCredentialPolicy(t *testing.T) {
const key = "foo"
cred := exported.NewKeyCredential(key)

const headerName = "fake-auth"
policy := NewKeyCredentialPolicy(cred, headerName, nil)
require.NotNil(t, policy)

pl := exported.NewPipeline(shared.TransportFunc(func(req *http.Request) (*http.Response, error) {
require.EqualValues(t, key, req.Header.Get(headerName))
return &http.Response{}, nil
}), policy)

req, err := NewRequest(context.Background(), http.MethodGet, "http://contoso.com")
require.NoError(t, err)

_, err = pl.Do(req)
require.NoError(t, err)

policy = NewKeyCredentialPolicy(cred, headerName, &KeyCredentialPolicyOptions{
Prefix: "Prefix: ",
})
require.NotNil(t, policy)

pl = exported.NewPipeline(shared.TransportFunc(func(req *http.Request) (*http.Response, error) {
require.EqualValues(t, "Prefix: "+key, req.Header.Get(headerName))
return &http.Response{}, nil
}), policy)

req, err = NewRequest(context.Background(), http.MethodGet, "http://contoso.com")
require.NoError(t, err)

_, err = pl.Do(req)
require.NoError(t, err)
}
39 changes: 39 additions & 0 deletions sdk/azcore/runtime/policy_sas_credential.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package runtime

import (
"net/http"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
)

// SASCredentialPolicy authorizes requests with a [azcore.SASCredential].
type SASCredentialPolicy struct {
cred *exported.SASCredential
header string
}

// SASCredentialPolicyOptions contains the optional values configuring [SASCredentialPolicy].
type SASCredentialPolicyOptions struct {
// placeholder for future optional values
}

// NewSASCredentialPolicy creates a new instance of [SASCredentialPolicy].
// - cred is the [azcore.SASCredential] used to authenticate with the service
// - header is the name of the HTTP request header in which the shared access signature is placed
// - options contains optional configuration, pass nil to accept the default values
func NewSASCredentialPolicy(cred *exported.SASCredential, header string, options *SASCredentialPolicyOptions) *SASCredentialPolicy {
return &SASCredentialPolicy{
cred: cred,
header: header,
}
}

// Do implementes the Do method on the [policy.Polilcy] interface.
func (k *SASCredentialPolicy) Do(req *policy.Request) (*http.Response, error) {
req.Raw().Header.Add(k.header, exported.SASCredentialGet(k.cred))
return req.Next()
}
34 changes: 34 additions & 0 deletions sdk/azcore/runtime/polilcy_sas_credential_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package runtime

import (
"context"
"net/http"
"testing"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared"
"github.com/stretchr/testify/require"
)

func TestSASCredentialPolicy(t *testing.T) {
const key = "foo"
cred := exported.NewSASCredential(key)

const headerName = "fake-auth"
policy := NewSASCredentialPolicy(cred, headerName, nil)
require.NotNil(t, policy)

pl := exported.NewPipeline(shared.TransportFunc(func(req *http.Request) (*http.Response, error) {
require.EqualValues(t, key, req.Header.Get(headerName))
return &http.Response{}, nil
}), policy)

req, err := NewRequest(context.Background(), http.MethodGet, "http://contoso.com")
require.NoError(t, err)

_, err = pl.Do(req)
require.NoError(t, err)
}