Skip to content

Commit

Permalink
feat: add CredentialProvider constructor for momento-local connections (
Browse files Browse the repository at this point in the history
#576)

* feat: add CredentialProvider constructor for momento-local connections

* fix: refactor data client grpc manager

* fix unit tests

* remove eager connect from data grpc manager request
  • Loading branch information
anitarua authored Jan 17, 2025
1 parent c6510e2 commit ebdbe74
Show file tree
Hide file tree
Showing 13 changed files with 198 additions and 111 deletions.
154 changes: 117 additions & 37 deletions auth/credential_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,62 +12,93 @@ import (
"github.com/momentohq/client-sdk-go/internal/momentoerrors"
)

type Endpoints struct {
type Endpoint struct {
// Endpoint is the host which the Momento client will connect to
Endpoint string
// InsecureConnection is a flag to indicate whether the connection to the endpoint should be insecure. The zero value for a bool in Go is false, so we default to a secure connection (InsecureConnection==false) if a value is not provided.
InsecureConnection bool
}

type AllEndpoints struct {
// ControlEndpoint is the host which the Momento client will connect to the Momento control plane
ControlEndpoint string
ControlEndpoint Endpoint
// CacheEndpoint is the host which the Momento client will connect to the Momento data plane
CacheEndpoint string
CacheEndpoint Endpoint
// TokenEndpoint is the host which the Momento client will connect to for generating disposable auth tokens
TokenEndpoint string
TokenEndpoint Endpoint
// StorageEndpoint is the host which the Momento client will connect to the Momento storage data plane
StorageEndpoint string
StorageEndpoint Endpoint
}

type tokenAndEndpoints struct {
Endpoints
Endpoints AllEndpoints
AuthToken string
}

type CredentialProvider interface {
GetAuthToken() string
GetControlEndpoint() string
IsControlEndpointSecure() bool
GetCacheEndpoint() string
IsCacheEndpointSecure() bool
GetTokenEndpoint() string
IsTokenEndpointSecure() bool
GetStorageEndpoint() string
WithEndpoints(endpoints Endpoints) (CredentialProvider, error)
IsStorageEndpointSecure() bool
WithEndpoints(endpoints AllEndpoints) (CredentialProvider, error)
}

type defaultCredentialProvider struct {
authToken string
controlEndpoint string
cacheEndpoint string
tokenEndpoint string
storageEndpoint string
controlEndpoint Endpoint
cacheEndpoint Endpoint
tokenEndpoint Endpoint
storageEndpoint Endpoint
}

// GetAuthToken returns user's auth token.
func (credentialProvider defaultCredentialProvider) GetAuthToken() string {
return credentialProvider.authToken
}

// GetControlEndpoint returns Endpoints.ControlEndpoint.
// GetControlEndpoint returns AllEndpoints.ControlEndpoint.Endpoint.
func (credentialProvider defaultCredentialProvider) GetControlEndpoint() string {
return credentialProvider.controlEndpoint
return credentialProvider.controlEndpoint.Endpoint
}

// IsControlEndpointSecure returns true if the control endpoint is secure.
func (credentialProvider defaultCredentialProvider) IsControlEndpointSecure() bool {
return !credentialProvider.controlEndpoint.InsecureConnection
}

// GetCacheEndpoint returns Endpoints.CacheEndpoint.
// GetCacheEndpoint returns AllEndpoints.CacheEndpoint.Endpoint.
func (credentialProvider defaultCredentialProvider) GetCacheEndpoint() string {
return credentialProvider.cacheEndpoint
return credentialProvider.cacheEndpoint.Endpoint
}

// GetTokenEndpoint returns Endpoints.TokenEndpoint.
// IsCacheEndpointSecure returns true if the cace endpoint is secure.
func (credentialProvider defaultCredentialProvider) IsCacheEndpointSecure() bool {
return !credentialProvider.cacheEndpoint.InsecureConnection
}

// GetTokenEndpoint returns AllEndpoints.TokenEndpoint.Endpoint.
func (credentialProvider defaultCredentialProvider) GetTokenEndpoint() string {
return credentialProvider.tokenEndpoint
return credentialProvider.tokenEndpoint.Endpoint
}

// IsTokenEndpointSecure returns true if the token endpoint is secure.
func (credentialProvider defaultCredentialProvider) IsTokenEndpointSecure() bool {
return !credentialProvider.tokenEndpoint.InsecureConnection
}

// GetStorageEndpoint returns Endpoints.StorageEndpoint.
// GetStorageEndpoint returns AllEndpoints.StorageEndpoint.Endpoint.
func (credentialProvider defaultCredentialProvider) GetStorageEndpoint() string {
return credentialProvider.storageEndpoint
return credentialProvider.storageEndpoint.Endpoint
}

// IsStorageEndpointSecure returns true if the storage endpoint is secure.
func (credentialProvider defaultCredentialProvider) IsStorageEndpointSecure() bool {
return !credentialProvider.storageEndpoint.InsecureConnection
}

// FromEnvironmentVariable returns a new CredentialProvider using an auth token stored in the provided environment variable.
Expand All @@ -91,17 +122,17 @@ func FromString(authToken string) (CredentialProvider, error) {
// WithEndpoints overrides the cache and control endpoint URIs with those provided by the supplied Endpoints struct
// and returns a CredentialProvider with the new endpoint values. An endpoint supplied as an empty string is ignored
// and the existing value for that endpoint is retained.
func (credentialProvider defaultCredentialProvider) WithEndpoints(endpoints Endpoints) (CredentialProvider, error) {
if endpoints.CacheEndpoint != "" {
func (credentialProvider defaultCredentialProvider) WithEndpoints(endpoints AllEndpoints) (CredentialProvider, error) {
if endpoints.CacheEndpoint.Endpoint != "" {
credentialProvider.cacheEndpoint = endpoints.CacheEndpoint
}
if endpoints.ControlEndpoint != "" {
if endpoints.ControlEndpoint.Endpoint != "" {
credentialProvider.controlEndpoint = endpoints.ControlEndpoint
}
if endpoints.TokenEndpoint != "" {
if endpoints.TokenEndpoint.Endpoint != "" {
credentialProvider.tokenEndpoint = endpoints.TokenEndpoint
}
if endpoints.StorageEndpoint != "" {
if endpoints.StorageEndpoint.Endpoint != "" {
credentialProvider.storageEndpoint = endpoints.StorageEndpoint
}
return credentialProvider, nil
Expand All @@ -128,12 +159,21 @@ func NewStringMomentoTokenProvider(authToken string) (CredentialProvider, error)
if err != nil {
return nil, err
}
port := 443
provider := defaultCredentialProvider{
authToken: tokenAndEndpoints.AuthToken,
controlEndpoint: tokenAndEndpoints.ControlEndpoint,
cacheEndpoint: tokenAndEndpoints.CacheEndpoint,
tokenEndpoint: tokenAndEndpoints.TokenEndpoint,
storageEndpoint: tokenAndEndpoints.StorageEndpoint,
authToken: tokenAndEndpoints.AuthToken,
controlEndpoint: Endpoint{
Endpoint: fmt.Sprintf("%s:%d", tokenAndEndpoints.Endpoints.ControlEndpoint.Endpoint, port),
},
cacheEndpoint: Endpoint{
Endpoint: fmt.Sprintf("%s:%d", tokenAndEndpoints.Endpoints.CacheEndpoint.Endpoint, port),
},
tokenEndpoint: Endpoint{
Endpoint: fmt.Sprintf("%s:%d", tokenAndEndpoints.Endpoints.TokenEndpoint.Endpoint, port),
},
storageEndpoint: Endpoint{
Endpoint: fmt.Sprintf("%s:%d", tokenAndEndpoints.Endpoints.StorageEndpoint.Endpoint, port),
},
}
return provider, nil
}
Expand Down Expand Up @@ -165,11 +205,11 @@ func processV1Token(decodedBase64Token []byte) (*tokenAndEndpoints, momentoerror
}

return &tokenAndEndpoints{
Endpoints: Endpoints{
ControlEndpoint: fmt.Sprintf("control.%s", tokenData["endpoint"]),
CacheEndpoint: fmt.Sprintf("cache.%s", tokenData["endpoint"]),
TokenEndpoint: fmt.Sprintf("token.%s", tokenData["endpoint"]),
StorageEndpoint: fmt.Sprintf("storage.%s", tokenData["endpoint"]),
Endpoints: AllEndpoints{
ControlEndpoint: Endpoint{Endpoint: fmt.Sprintf("control.%s", tokenData["endpoint"])},
CacheEndpoint: Endpoint{Endpoint: fmt.Sprintf("cache.%s", tokenData["endpoint"])},
TokenEndpoint: Endpoint{Endpoint: fmt.Sprintf("token.%s", tokenData["endpoint"])},
StorageEndpoint: Endpoint{Endpoint: fmt.Sprintf("storage.%s", tokenData["endpoint"])},
},
AuthToken: tokenData["api_key"],
}, nil
Expand All @@ -188,9 +228,9 @@ func processJwtToken(authToken string) (*tokenAndEndpoints, momentoerrors.Moment
controlEndpoint := reflect.ValueOf(claims["cp"]).String()
cacheEndpoint := reflect.ValueOf(claims["c"]).String()
return &tokenAndEndpoints{
Endpoints: Endpoints{
ControlEndpoint: controlEndpoint,
CacheEndpoint: cacheEndpoint,
Endpoints: AllEndpoints{
ControlEndpoint: Endpoint{Endpoint: controlEndpoint},
CacheEndpoint: Endpoint{Endpoint: cacheEndpoint},
},
AuthToken: authToken,
}, nil
Expand All @@ -201,3 +241,43 @@ func processJwtToken(authToken string) (*tokenAndEndpoints, momentoerrors.Moment
nil,
)
}

type MomentoLocalConfig struct {
Hostname string
Port uint
}

func NewMomentoLocalProvider(config *MomentoLocalConfig) (CredentialProvider, error) {
hostname := "127.0.0.1"
port := uint(8080)
if config != nil {
if config.Hostname != "" {
hostname = config.Hostname
}
if config.Port != 0 {
port = config.Port
}
}

momentoLocalEndpoint := Endpoint{
Endpoint: fmt.Sprintf("%s:%d", hostname, port),
InsecureConnection: true,
}
tokenAndEndpoints := &tokenAndEndpoints{
Endpoints: AllEndpoints{
ControlEndpoint: momentoLocalEndpoint,
CacheEndpoint: momentoLocalEndpoint,
TokenEndpoint: momentoLocalEndpoint,
StorageEndpoint: momentoLocalEndpoint,
},
}

provider := defaultCredentialProvider{
authToken: tokenAndEndpoints.AuthToken,
controlEndpoint: tokenAndEndpoints.Endpoints.ControlEndpoint,
cacheEndpoint: tokenAndEndpoints.Endpoints.CacheEndpoint,
tokenEndpoint: tokenAndEndpoints.Endpoints.TokenEndpoint,
storageEndpoint: tokenAndEndpoints.Endpoints.StorageEndpoint,
}
return provider, nil
}
59 changes: 48 additions & 11 deletions auth/credential_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,32 +51,32 @@ var _ = Describe("auth credential-provider", func() {
credentialProvider, err := auth.NewEnvMomentoTokenProvider(envVar)
Expect(err).To(BeNil())
Expect(credentialProvider.GetAuthToken()).To(Equal(testV1ApiKey))
Expect(credentialProvider.GetCacheEndpoint()).To(Equal("cache.test.momentohq.com"))
Expect(credentialProvider.GetControlEndpoint()).To(Equal("control.test.momentohq.com"))
Expect(credentialProvider.GetCacheEndpoint()).To(Equal("cache.test.momentohq.com:443"))
Expect(credentialProvider.GetControlEndpoint()).To(Equal("control.test.momentohq.com:443"))
})

It("returns a credential provider from a string via constructor", func() {
credentialProvider, err := auth.NewStringMomentoTokenProvider(os.Getenv(envVar))
Expect(err).To(BeNil())
Expect(credentialProvider.GetAuthToken()).To(Equal(testV1ApiKey))
Expect(credentialProvider.GetCacheEndpoint()).To(Equal("cache.test.momentohq.com"))
Expect(credentialProvider.GetControlEndpoint()).To(Equal("control.test.momentohq.com"))
Expect(credentialProvider.GetCacheEndpoint()).To(Equal("cache.test.momentohq.com:443"))
Expect(credentialProvider.GetControlEndpoint()).To(Equal("control.test.momentohq.com:443"))
})

It("returns a credential provider from an environment variable via method", func() {
credentialProvider, err := auth.FromEnvironmentVariable(envVar)
Expect(err).To(BeNil())
Expect(credentialProvider.GetAuthToken()).To(Equal(testV1ApiKey))
Expect(credentialProvider.GetCacheEndpoint()).To(Equal("cache.test.momentohq.com"))
Expect(credentialProvider.GetControlEndpoint()).To(Equal("control.test.momentohq.com"))
Expect(credentialProvider.GetCacheEndpoint()).To(Equal("cache.test.momentohq.com:443"))
Expect(credentialProvider.GetControlEndpoint()).To(Equal("control.test.momentohq.com:443"))
})

It("returns a credential provider from a string via method", func() {
credentialProvider, err := auth.FromString(os.Getenv(envVar))
Expect(err).To(BeNil())
Expect(credentialProvider.GetAuthToken()).To(Equal(testV1ApiKey))
Expect(credentialProvider.GetCacheEndpoint()).To(Equal("cache.test.momentohq.com"))
Expect(credentialProvider.GetControlEndpoint()).To(Equal("control.test.momentohq.com"))
Expect(credentialProvider.GetCacheEndpoint()).To(Equal("cache.test.momentohq.com:443"))
Expect(credentialProvider.GetControlEndpoint()).To(Equal("control.test.momentohq.com:443"))
})

It("overrides endpoints", func() {
Expand All @@ -86,18 +86,22 @@ var _ = Describe("auth credential-provider", func() {
cacheEndpoint := credentialProvider.GetCacheEndpoint()
Expect(controlEndpoint).ToNot(BeEmpty())
Expect(cacheEndpoint).ToNot(BeEmpty())
Expect(credentialProvider.IsControlEndpointSecure()).To(BeTrue())
Expect(credentialProvider.IsCacheEndpointSecure()).To(BeTrue())

controlEndpoint = fmt.Sprintf("%s-overridden", controlEndpoint)
cacheEndpoint = fmt.Sprintf("%s-overridden", cacheEndpoint)
credentialProvider, err = credentialProvider.WithEndpoints(
auth.Endpoints{
ControlEndpoint: controlEndpoint,
CacheEndpoint: cacheEndpoint,
auth.AllEndpoints{
ControlEndpoint: auth.Endpoint{Endpoint: controlEndpoint},
CacheEndpoint: auth.Endpoint{Endpoint: cacheEndpoint},
},
)
Expect(err).To(BeNil())
Expect(credentialProvider.GetControlEndpoint()).To(Equal(controlEndpoint))
Expect(credentialProvider.GetCacheEndpoint()).To(Equal(cacheEndpoint))
Expect(credentialProvider.IsControlEndpointSecure()).To(BeTrue())
Expect(credentialProvider.IsCacheEndpointSecure()).To(BeTrue())
})

DescribeTable("errors when v1 token is missing data",
Expand All @@ -115,6 +119,39 @@ var _ = Describe("auth credential-provider", func() {
Entry("missing endpoint", testV1MissingEndpoint),
Entry("missing api key", testV1MissingApiKey),
)

It("correctly sets Momento Local endpoints", func() {
// Using default config
credentialProvider, err := auth.NewMomentoLocalProvider(nil)
defaultEndpoint := "127.0.0.1:8080"
Expect(err).To(BeNil())
Expect(credentialProvider.GetAuthToken()).To(Equal(""))
Expect(credentialProvider.GetCacheEndpoint()).To(Equal(defaultEndpoint))
Expect(credentialProvider.IsCacheEndpointSecure()).To(BeFalse())
Expect(credentialProvider.GetControlEndpoint()).To(Equal(defaultEndpoint))
Expect(credentialProvider.IsControlEndpointSecure()).To(BeFalse())
Expect(credentialProvider.GetStorageEndpoint()).To(Equal(defaultEndpoint))
Expect(credentialProvider.IsStorageEndpointSecure()).To(BeFalse())
Expect(credentialProvider.GetTokenEndpoint()).To(Equal(defaultEndpoint))
Expect(credentialProvider.IsTokenEndpointSecure()).To(BeFalse())

// Using provided config
credentialProvider, err = auth.NewMomentoLocalProvider(&auth.MomentoLocalConfig{
Hostname: "localhost",
Port: 9090,
})
nonDefaultEndpoint := "localhost:9090"
Expect(err).To(BeNil())
Expect(credentialProvider.GetAuthToken()).To(Equal(""))
Expect(credentialProvider.GetCacheEndpoint()).To(Equal(nonDefaultEndpoint))
Expect(credentialProvider.IsCacheEndpointSecure()).To(BeFalse())
Expect(credentialProvider.GetControlEndpoint()).To(Equal(nonDefaultEndpoint))
Expect(credentialProvider.IsControlEndpointSecure()).To(BeFalse())
Expect(credentialProvider.GetStorageEndpoint()).To(Equal(nonDefaultEndpoint))
Expect(credentialProvider.IsStorageEndpointSecure()).To(BeFalse())
Expect(credentialProvider.GetTokenEndpoint()).To(Equal(nonDefaultEndpoint))
Expect(credentialProvider.IsTokenEndpointSecure()).To(BeFalse())
})
})

})
7 changes: 2 additions & 5 deletions internal/grpcmanagers/auth_manager.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package grpcmanagers

import (
"fmt"

"github.com/momentohq/client-sdk-go/internal/interceptor"
"github.com/momentohq/client-sdk-go/internal/models"
"github.com/momentohq/client-sdk-go/internal/momentoerrors"
Expand All @@ -14,10 +12,8 @@ type AuthGrpcManager struct {
AuthToken string
}

const AuthPort = ":443"

func NewAuthGrpcManager(request *models.AuthGrpcManagerRequest) (*AuthGrpcManager, momentoerrors.MomentoSvcErr) {
endpoint := fmt.Sprint(request.CredentialProvider.GetControlEndpoint(), AuthPort)
endpoint := request.CredentialProvider.GetControlEndpoint()
authToken := request.CredentialProvider.GetAuthToken()

headerInterceptors := []grpc.UnaryClientInterceptor{
Expand All @@ -28,6 +24,7 @@ func NewAuthGrpcManager(request *models.AuthGrpcManagerRequest) (*AuthGrpcManage
endpoint,
AllDialOptions(
request.GrpcConfiguration,
request.CredentialProvider.IsControlEndpointSecure(),
grpc.WithChainUnaryInterceptor(headerInterceptors...),
)...,
)
Expand Down
Loading

0 comments on commit ebdbe74

Please sign in to comment.