Skip to content

Commit

Permalink
[Cosmos] Adds global endpoint manager policy and links GEM to client (#…
Browse files Browse the repository at this point in the history
…22223)

* added policy and linked gem to client

* refactor go routine for CI pipeline lint

* remove debugging print statement

* changed to internal pipeline, updated tests

* removed unneeded code

* saved apiVersion into constant, changed all explicit uses to constant

* add gem policy emulator test

* small changes to the test

* clean up test

* pass preferred regions

* removed excess comments

* add nil check to client options for preferred regions
  • Loading branch information
simorenoh authored Jan 17, 2024
1 parent 2633a10 commit 58ca6be
Show file tree
Hide file tree
Showing 7 changed files with 145 additions and 60 deletions.
44 changes: 40 additions & 4 deletions sdk/data/azcosmos/cosmos_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,15 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming"
)

const (
apiVersion = "2020-11-05"
)

// Client is used to interact with the Azure Cosmos DB database service.
type Client struct {
endpoint string
pipeline azruntime.Pipeline
gem *globalEndpointManager
}

// Endpoint used to create the client.
Expand All @@ -36,7 +41,15 @@ func (c *Client) Endpoint() string {
// cred - The credential used to authenticate with the cosmos service.
// options - Optional Cosmos client options. Pass nil to accept default values.
func NewClientWithKey(endpoint string, cred KeyCredential, o *ClientOptions) (*Client, error) {
return &Client{endpoint: endpoint, pipeline: newPipeline(newSharedKeyCredPolicy(cred), o)}, nil
preferredRegions := []string{}
if o != nil {
preferredRegions = o.PreferredRegions
}
gem, err := newGlobalEndpointManager(endpoint, newInternalPipeline(newSharedKeyCredPolicy(cred), o), preferredRegions, 0)
if err != nil {
return nil, err
}
return &Client{endpoint: endpoint, pipeline: newPipeline(newSharedKeyCredPolicy(cred), gem, o), gem: gem}, nil
}

// NewClient creates a new instance of Cosmos client with Azure AD access token authentication. It uses the default pipeline configuration.
Expand All @@ -48,7 +61,16 @@ func NewClient(endpoint string, cred azcore.TokenCredential, o *ClientOptions) (
if err != nil {
return nil, err
}
return &Client{endpoint: endpoint, pipeline: newPipeline(newCosmosBearerTokenPolicy(cred, scope, nil), o)}, nil
preferredRegions := []string{}
if o != nil {
preferredRegions = o.PreferredRegions
}
gem, err := newGlobalEndpointManager(endpoint, newInternalPipeline(newCosmosBearerTokenPolicy(cred, scope, nil), o), preferredRegions, 0)
if err != nil {
return nil, err
}

return &Client{endpoint: endpoint, pipeline: newPipeline(newCosmosBearerTokenPolicy(cred, scope, nil), gem, o), gem: gem}, nil
}

// NewClientFromConnectionString creates a new instance of Cosmos client from connection string. It uses the default pipeline configuration.
Expand Down Expand Up @@ -87,7 +109,7 @@ func NewClientFromConnectionString(connectionString string, o *ClientOptions) (*
return NewClientWithKey(endpoint, cred, o)
}

func newPipeline(authPolicy policy.Policy, options *ClientOptions) azruntime.Pipeline {
func newPipeline(authPolicy policy.Policy, gem *globalEndpointManager, options *ClientOptions) azruntime.Pipeline {
if options == nil {
options = &ClientOptions{}
}
Expand All @@ -98,6 +120,7 @@ func newPipeline(authPolicy policy.Policy, options *ClientOptions) azruntime.Pip
&headerPolicies{
enableContentResponseOnWrite: options.EnableContentResponseOnWrite,
},
&globalEndpointManagerPolicy{gem: gem},
},
PerRetry: []policy.Policy{
authPolicy,
Expand All @@ -106,6 +129,19 @@ func newPipeline(authPolicy policy.Policy, options *ClientOptions) azruntime.Pip
&options.ClientOptions)
}

func newInternalPipeline(authPolicy policy.Policy, options *ClientOptions) azruntime.Pipeline {
if options == nil {
options = &ClientOptions{}
}
return azruntime.NewPipeline("azcosmos", serviceLibVersion,
azruntime.PipelineOptions{
PerRetry: []policy.Policy{
authPolicy,
},
},
&options.ClientOptions)
}

func createScopeFromEndpoint(endpoint string) ([]string, error) {
u, err := url.Parse(endpoint)
if err != nil {
Expand Down Expand Up @@ -394,7 +430,7 @@ func (c *Client) createRequest(
}

req.Raw().Header.Set(headerXmsDate, time.Now().UTC().Format(http.TimeFormat))
req.Raw().Header.Set(headerXmsVersion, "2020-11-05")
req.Raw().Header.Set(headerXmsVersion, apiVersion)
req.Raw().Header.Set(cosmosHeaderSDKSupportedCapabilities, supportedCapabilitiesHeaderValue)

req.SetOperationValue(operationContext)
Expand Down
4 changes: 2 additions & 2 deletions sdk/data/azcosmos/cosmos_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,8 @@ func TestCreateRequest(t *testing.T) {
t.Errorf("Expected %v, but got %v", "", req.Raw().Header.Get(headerXmsDate))
}

if req.Raw().Header.Get(headerXmsVersion) != "2020-11-05" {
t.Errorf("Expected %v, but got %v", "2020-11-05", req.Raw().Header.Get(headerXmsVersion))
if req.Raw().Header.Get(headerXmsVersion) != apiVersion {
t.Errorf("Expected %v, but got %v", apiVersion, req.Raw().Header.Get(headerXmsVersion))
}

if req.Raw().Header.Get(cosmosHeaderSDKSupportedCapabilities) != supportedCapabilitiesHeaderValue {
Expand Down
40 changes: 26 additions & 14 deletions sdk/data/azcosmos/cosmos_global_endpoint_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,17 @@ import (
const defaultUnavailableLocationRefreshInterval = 5 * time.Minute

type globalEndpointManager struct {
client *Client
clientEndpoint string
pipeline azruntime.Pipeline
preferredLocations []string
locationCache *locationCache
refreshTimeInterval time.Duration
gemMutex sync.Mutex
lastUpdateTime time.Time
}

func newGlobalEndpointManager(client *Client, preferredLocations []string, refreshTimeInterval time.Duration) (*globalEndpointManager, error) {
endpoint, err := url.Parse(client.endpoint)
func newGlobalEndpointManager(clientEndpoint string, pipeline azruntime.Pipeline, preferredLocations []string, refreshTimeInterval time.Duration) (*globalEndpointManager, error) {
endpoint, err := url.Parse(clientEndpoint)
if err != nil {
return &globalEndpointManager{}, err
}
Expand All @@ -36,7 +37,8 @@ func newGlobalEndpointManager(client *Client, preferredLocations []string, refre
}

gem := &globalEndpointManager{
client: client,
clientEndpoint: clientEndpoint,
pipeline: pipeline,
preferredLocations: preferredLocations,
locationCache: newLocationCache(preferredLocations, *endpoint),
refreshTimeInterval: refreshTimeInterval,
Expand Down Expand Up @@ -110,24 +112,34 @@ func (gem *globalEndpointManager) GetAccountProperties(ctx context.Context) (acc
resourceAddress: "",
}

path, err := generatePathForNameBased(resourceTypeDatabaseAccount, "", false)
ctxt, cancel := context.WithTimeout(ctx, 60*time.Second)
defer cancel()
req, err := azruntime.NewRequest(ctxt, http.MethodGet, gem.clientEndpoint)
if err != nil {
return accountProperties{}, fmt.Errorf("failed to generate path for name-based request: %v", err)
return accountProperties{}, err
}

ctx, cancel := context.WithTimeout(ctx, 60*time.Second)
azResponse, err := gem.client.sendGetRequest(path, ctx, operationContext, nil, nil)
cancel()
req.Raw().Header.Set(headerXmsDate, time.Now().UTC().Format(http.TimeFormat))
req.Raw().Header.Set(headerXmsVersion, apiVersion)
req.Raw().Header.Set(cosmosHeaderSDKSupportedCapabilities, supportedCapabilitiesHeaderValue)

req.SetOperationValue(operationContext)

azResponse, err := gem.pipeline.Do(req)
if err != nil {
return accountProperties{}, fmt.Errorf("failed to retrieve account properties: %v", err)
return accountProperties{}, err
}

properties, err := newAccountProperties(azResponse)
if err != nil {
return accountProperties{}, fmt.Errorf("failed to parse account properties: %v", err)
successResponse := (azResponse.StatusCode >= 200 && azResponse.StatusCode < 300)
if successResponse {
properties, err := newAccountProperties(azResponse)
if err != nil {
return accountProperties{}, fmt.Errorf("failed to parse account properties: %v", err)
}
return properties, nil
}

return properties, nil
return accountProperties{}, newCosmosError(azResponse)
}

func newAccountProperties(azResponse *http.Response) (accountProperties, error) {
Expand Down
25 changes: 25 additions & 0 deletions sdk/data/azcosmos/cosmos_global_endpoint_manager_policy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package azcosmos

import (
"context"
"net/http"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
)

type globalEndpointManagerPolicy struct {
gem *globalEndpointManager
}

func (p *globalEndpointManagerPolicy) Do(req *policy.Request) (*http.Response, error) {
shouldRefresh := p.gem.ShouldRefresh()
if shouldRefresh {
go func() {
_ = p.gem.Update(context.Background())
}()
}
return req.Next()
}
51 changes: 16 additions & 35 deletions sdk/data/azcosmos/cosmos_global_endpoint_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,7 @@ func TestGlobalEndpointManagerGetWriteEndpoints(t *testing.T) {

pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: srv})

client := &Client{endpoint: srv.URL(), pipeline: pl}

preferredRegions := []string{"West US", "Central US"}

gem, err := newGlobalEndpointManager(client, preferredRegions, 5*time.Minute)
gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{"West US", "Central US"}, 5*time.Minute)
assert.NoError(t, err)

writeEndpoints, err := gem.GetWriteEndpoints()
Expand All @@ -50,6 +46,7 @@ func TestGlobalEndpointManagerGetWriteEndpoints(t *testing.T) {
expectedWriteEndpoints := []url.URL{
*serverEndpoint,
}

assert.Equal(t, expectedWriteEndpoints, writeEndpoints)
}

Expand All @@ -60,11 +57,7 @@ func TestGlobalEndpointManagerGetReadEndpoints(t *testing.T) {

pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: srv})

client := &Client{endpoint: srv.URL(), pipeline: pl}

preferredRegions := []string{"West US", "Central US"}

gem, err := newGlobalEndpointManager(client, preferredRegions, 5*time.Minute)
gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{"West US", "Central US"}, 5*time.Minute)
assert.NoError(t, err)

readEndpoints, err := gem.GetReadEndpoints()
Expand All @@ -88,12 +81,10 @@ func TestGlobalEndpointManagerMarkEndpointUnavailableForRead(t *testing.T) {

client := &Client{endpoint: srv.URL(), pipeline: pl}

preferredRegions := []string{"West US", "Central US"}

gem, err := newGlobalEndpointManager(client, preferredRegions, 5*time.Minute)
endpoint, err := url.Parse(client.endpoint)
assert.NoError(t, err)

endpoint, err := url.Parse(client.endpoint)
gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{"West US", "Central US"}, 5*time.Minute)
assert.NoError(t, err)

err = gem.MarkEndpointUnavailableForRead(*endpoint)
Expand All @@ -112,12 +103,10 @@ func TestGlobalEndpointManagerMarkEndpointUnavailableForWrite(t *testing.T) {

client := &Client{endpoint: srv.URL(), pipeline: pl}

preferredRegions := []string{"West US", "Central US"}

gem, err := newGlobalEndpointManager(client, preferredRegions, 5*time.Minute)
endpoint, err := url.Parse(client.endpoint)
assert.NoError(t, err)

endpoint, err := url.Parse(client.endpoint)
gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{"West US", "Central US"}, 5*time.Minute)
assert.NoError(t, err)

err = gem.MarkEndpointUnavailableForWrite(*endpoint)
Expand All @@ -130,7 +119,6 @@ func TestGlobalEndpointManagerMarkEndpointUnavailableForWrite(t *testing.T) {
func TestGlobalEndpointManagerGetEndpointLocation(t *testing.T) {
srv, close := mock.NewTLSServer()
defer close()
srv.SetResponse(mock.WithStatusCode(http.StatusOK))

westRegion := accountRegion{
Name: "West US",
Expand All @@ -144,19 +132,17 @@ func TestGlobalEndpointManagerGetEndpointLocation(t *testing.T) {
}

jsonString, err := json.Marshal(properties)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)

srv.SetResponse(mock.WithStatusCode(200))
srv.SetResponse(mock.WithBody(jsonString))

pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: srv})

client := &Client{endpoint: srv.URL(), pipeline: pl}

gem, err := newGlobalEndpointManager(client, []string{}, 5*time.Minute)
serverEndpoint, err := url.Parse(srv.URL())
assert.NoError(t, err)

serverEndpoint, err := url.Parse(srv.URL())
gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{}, 5*time.Minute)
assert.NoError(t, err)

err = gem.Update(context.Background())
Expand All @@ -175,11 +161,7 @@ func TestGlobalEndpointManagerGetAccountProperties(t *testing.T) {

pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{}, &policy.ClientOptions{Transport: srv})

client := &Client{endpoint: srv.URL(), pipeline: pl}

preferredRegions := []string{"West US", "Central US"}

gem, err := newGlobalEndpointManager(client, preferredRegions, 5*time.Minute)
gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{"West US", "Central US"}, 5*time.Minute)
assert.NoError(t, err)

accountProps, err := gem.GetAccountProperties(context.Background())
Expand Down Expand Up @@ -212,13 +194,13 @@ func TestGlobalEndpointManagerCanUseMultipleWriteLocations(t *testing.T) {
mockLc.useMultipleWriteLocations = true

mockGem := globalEndpointManager{
client: client,
clientEndpoint: client.endpoint,
preferredLocations: preferredRegions,
locationCache: mockLc,
refreshTimeInterval: 5 * time.Minute,
}

gem, err := newGlobalEndpointManager(client, preferredRegions, 5*time.Minute)
gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{}, 5*time.Minute)
assert.NoError(t, err)

// Multiple locations should be false for default GEM
Expand Down Expand Up @@ -254,9 +236,8 @@ func TestGlobalEndpointManagerConcurrentUpdate(t *testing.T) {
srv.SetResponse(mock.WithBody(jsonString))

pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{PerCall: []policy.Policy{countPolicy}}, &policy.ClientOptions{Transport: srv})
client := &Client{endpoint: srv.URL(), pipeline: pl}

gem, err := newGlobalEndpointManager(client, []string{}, 5*time.Second)
gem, err := newGlobalEndpointManager(srv.URL(), pl, []string{}, 5*time.Second)
assert.NoError(t, err)

// Call update concurrently and see how many times the policy gets called
Expand Down
35 changes: 33 additions & 2 deletions sdk/data/azcosmos/emulator_cosmos_global_endpoint_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func TestGlobalEndpointManagerEmulator(t *testing.T) {
preferredRegions := []string{}
emulatorRegion := accountRegion{Name: emulatorRegionName, Endpoint: "https://127.0.0.1:8081/"}

gem, err := newGlobalEndpointManager(client, preferredRegions, 5*time.Minute)
gem, err := newGlobalEndpointManager(client.endpoint, client.pipeline, preferredRegions, 5*time.Minute)
assert.NoError(t, err)

accountProps, err := gem.GetAccountProperties(context.Background())
Expand Down Expand Up @@ -61,7 +61,7 @@ func TestGlobalEndpointManagerEmulator(t *testing.T) {
assert.Equal(t, locationInfo.availReadEndpointsByLocation, availableEndpointsByLocation)
assert.Equal(t, locationInfo.availWriteEndpointsByLocation, availableEndpointsByLocation)

//update and assert available locations are now populated in location cache
// Run Update() and assert available locations are now populated in location cache
err = gem.Update(context.Background())
assert.NoError(t, err)
locationInfo = gem.locationCache.locationInfo
Expand All @@ -73,3 +73,34 @@ func TestGlobalEndpointManagerEmulator(t *testing.T) {
assert.Equal(t, len(locationInfo.availReadEndpointsByLocation), len(availableEndpointsByLocation)+1)
assert.Equal(t, len(locationInfo.availWriteEndpointsByLocation), len(availableEndpointsByLocation)+1)
}

func TestGlobalEndpointManagerPolicyEmulator(t *testing.T) {
emulatorTests := newEmulatorTests(t)
client := emulatorTests.getClient(t)
emulatorRegionName := "South Central US"

// Assert location cache is not populated until update() is called within the policy
locationInfo := client.gem.locationCache.locationInfo
availableLocation := []string{}
availableEndpointsByLocation := map[string]url.URL{}

assert.Equal(t, locationInfo.availReadLocations, availableLocation)
assert.Equal(t, locationInfo.availWriteLocations, availableLocation)
assert.Equal(t, locationInfo.availReadEndpointsByLocation, availableEndpointsByLocation)
assert.Equal(t, locationInfo.availWriteEndpointsByLocation, availableEndpointsByLocation)

// Assert that information gets populated by the gem policy after running an http request (read item)
db, _ := client.NewDatabase("database_id")
container, _ := db.NewContainer("container_id")
_, err := container.ReadItem(context.TODO(), NewPartitionKeyString("1"), "doc1", nil)
assert.Error(t, err)

locationInfo = client.gem.locationCache.locationInfo

assert.Equal(t, len(locationInfo.availReadLocations), len(availableLocation)+1)
assert.Equal(t, len(locationInfo.availWriteLocations), len(availableLocation)+1)
assert.Equal(t, locationInfo.availWriteLocations[0], emulatorRegionName)
assert.Equal(t, locationInfo.availReadLocations[0], emulatorRegionName)
assert.Equal(t, len(locationInfo.availReadEndpointsByLocation), len(availableEndpointsByLocation)+1)
assert.Equal(t, len(locationInfo.availWriteEndpointsByLocation), len(availableEndpointsByLocation)+1)
}
Loading

0 comments on commit 58ca6be

Please sign in to comment.