Skip to content

Commit

Permalink
Respect expiration of password token source (#163)
Browse files Browse the repository at this point in the history
* Respect expiration of password token source

When using username/password authentication the Client fetches a token
on creation to be used as the token source. This token expires
eventually and is only refreshable by creating a new Client object.

Capture the expiration of the token and refresh the oauth client when it
expires.

fixes #34
updates cloudfoundry-community/stackdriver-tools#177

* grab token expiration after successful fetch
  • Loading branch information
johnsonj authored and lnguyen committed Jan 5, 2018
1 parent dc99b02 commit 2eadb63
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 54 deletions.
6 changes: 3 additions & 3 deletions cf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func testPostQuery(req *http.Request, postFormBody *string, t *testing.T) {
func setupMultiple(mockEndpoints []MockRoute, t *testing.T) {
mux = http.NewServeMux()
server = httptest.NewServer(mux)
fakeUAAServer = FakeUAAServer()
fakeUAAServer = FakeUAAServer(3)
m := martini.New()
m.Use(render.Renderer())
r := martini.NewRouter()
Expand Down Expand Up @@ -120,7 +120,7 @@ func setupMultiple(mockEndpoints []MockRoute, t *testing.T) {
mux.Handle("/", m)
}

func FakeUAAServer() *httptest.Server {
func FakeUAAServer(expiresIn int) *httptest.Server {
mux := http.NewServeMux()
server := httptest.NewServer(mux)
m := martini.New()
Expand All @@ -132,7 +132,7 @@ func FakeUAAServer() *httptest.Server {
"token_type": "bearer",
"access_token": "foobar" + strconv.Itoa(count),
"refresh_token": "barfoo",
"expires_in": 3,
"expires_in": expiresIn,
})
count = count + 1
})
Expand Down
117 changes: 67 additions & 50 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"golang.org/x/net/context"
"golang.org/x/oauth2"
"golang.org/x/oauth2/clientcredentials"
"time"
)

//Client used to communicate with Cloud Foundry
Expand All @@ -29,16 +30,17 @@ type Endpoint struct {

//Config is used to configure the creation of a client
type Config struct {
ApiAddress string `json:"api_url"`
Username string `json:"user"`
Password string `json:"password"`
ClientID string `json:"client_id"`
ClientSecret string `json:"client_secret"`
SkipSslValidation bool `json:"skip_ssl_validation"`
HttpClient *http.Client
Token string `json:"auth_token"`
TokenSource oauth2.TokenSource
UserAgent string `json:"user_agent"`
ApiAddress string `json:"api_url"`
Username string `json:"user"`
Password string `json:"password"`
ClientID string `json:"client_id"`
ClientSecret string `json:"client_secret"`
SkipSslValidation bool `json:"skip_ssl_validation"`
HttpClient *http.Client
Token string `json:"auth_token"`
TokenSource oauth2.TokenSource
tokenSourceDeadline *time.Time
UserAgent string `json:"user_agent"`
}

// request is used to help build up a request
Expand All @@ -55,13 +57,13 @@ type request struct {
//Need to be remove in close future
func DefaultConfig() *Config {
return &Config{
ApiAddress: "http://api.bosh-lite.com",
Username: "admin",
Password: "admin",
Token: "",
SkipSslValidation: false,
HttpClient: http.DefaultClient,
UserAgent: "Go-CF-client/1.1",
ApiAddress: "http://api.bosh-lite.com",
Username: "admin",
Password: "admin",
Token: "",
SkipSslValidation: false,
HttpClient: http.DefaultClient,
UserAgent: "Go-CF-client/1.1",
}
}

Expand Down Expand Up @@ -125,37 +127,14 @@ func NewClient(config *Config) (client *Client, err error) {
tp.TLSClientConfig.InsecureSkipVerify = config.SkipSslValidation
}

// we want to keep the Timeout value from config.HttpClient
timeout := config.HttpClient.Timeout

ctx := context.Background()
ctx = context.WithValue(ctx, oauth2.HTTPClient, config.HttpClient)

endpoint, err := getInfo(config.ApiAddress, oauth2.NewClient(ctx, nil))

if err != nil {
return nil, errors.Wrap(err, "Could not get api /v2/info")
client = &Client{
Config: *config,
}

switch {
case config.Token != "":
config = getUserTokenAuth(ctx, config, endpoint)
case config.ClientID != "":
config = getClientAuth(ctx, config, endpoint)
default:
config, err = getUserAuth(ctx, config, endpoint)
if err != nil {
return nil, err
}
}
// make sure original Timeout value will be used
if config.HttpClient.Timeout != timeout {
config.HttpClient.Timeout = timeout
}
client = &Client{
Config: *config,
Endpoint: *endpoint,
if err := client.refreshEndpoint(); err != nil {
return nil, err
}

return client, nil
}

Expand All @@ -168,7 +147,7 @@ func shallowDefaultTransport() *http.Transport {
}
}

func getUserAuth(ctx context.Context, config *Config, endpoint *Endpoint) (*Config, error) {
func getUserAuth(ctx context.Context, config Config, endpoint *Endpoint) (Config, error) {
authConfig := &oauth2.Config{
ClientID: "cf",
Scopes: []string{""},
Expand All @@ -179,18 +158,18 @@ func getUserAuth(ctx context.Context, config *Config, endpoint *Endpoint) (*Conf
}

token, err := authConfig.PasswordCredentialsToken(ctx, config.Username, config.Password)

if err != nil {
return nil, errors.Wrap(err, "Error getting token")
return config, errors.Wrap(err, "Error getting token")
}

config.tokenSourceDeadline = &token.Expiry
config.TokenSource = authConfig.TokenSource(ctx, token)
config.HttpClient = oauth2.NewClient(ctx, config.TokenSource)

return config, err
}

func getClientAuth(ctx context.Context, config *Config, endpoint *Endpoint) *Config {
func getClientAuth(ctx context.Context, config Config, endpoint *Endpoint) Config {
authConfig := &clientcredentials.Config{
ClientID: config.ClientID,
ClientSecret: config.ClientSecret,
Expand All @@ -203,7 +182,7 @@ func getClientAuth(ctx context.Context, config *Config, endpoint *Endpoint) *Con
}

// getUserTokenAuth initializes client credentials from existing bearer token.
func getUserTokenAuth(ctx context.Context, config *Config, endpoint *Endpoint) *Config {
func getUserTokenAuth(ctx context.Context, config Config, endpoint *Endpoint) Config {
authConfig := &oauth2.Config{
ClientID: "cf",
Scopes: []string{""},
Expand Down Expand Up @@ -293,6 +272,38 @@ func (c *Client) DoRequest(r *request) (*http.Response, error) {
return resp, nil
}

func (c *Client) refreshEndpoint() error {
// we want to keep the Timeout value from config.HttpClient
timeout := c.Config.HttpClient.Timeout

ctx := context.Background()
ctx = context.WithValue(ctx, oauth2.HTTPClient, c.Config.HttpClient)

endpoint, err := getInfo(c.Config.ApiAddress, oauth2.NewClient(ctx, nil))

if err != nil {
return errors.Wrap(err, "Could not get api /v2/info")
}

switch {
case c.Config.Token != "":
c.Config = getUserTokenAuth(ctx, c.Config, endpoint)
case c.Config.ClientID != "":
c.Config = getClientAuth(ctx, c.Config, endpoint)
default:
c.Config, err = getUserAuth(ctx, c.Config, endpoint)
if err != nil {
return err
}
}
// make sure original Timeout value will be used
if c.Config.HttpClient.Timeout != timeout {
c.Config.HttpClient.Timeout = timeout
}

return nil
}

// toHTTP converts the request to an HTTP request
func (r *request) toHTTP() (*http.Request, error) {

Expand Down Expand Up @@ -327,6 +338,12 @@ func encodeBody(obj interface{}) (io.Reader, error) {
}

func (c *Client) GetToken() (string, error) {
if c.Config.tokenSourceDeadline != nil && c.Config.tokenSourceDeadline.Before(time.Now()) {
if err := c.refreshEndpoint(); err != nil {
return "", err
}
}

token, err := c.Config.TokenSource.Token()
if err != nil {
return "", errors.Wrap(err, "Error getting bearer token")
Expand Down
28 changes: 27 additions & 1 deletion client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ func TestTokenRefresh(t *testing.T) {
gomega.RegisterTestingT(t)
Convey("Test making request", t, func() {
setup(MockRoute{"GET", "/v2/organizations", listOrgsPayload, "", 200, "", nil}, t)
fakeUAAServer = FakeUAAServer(1)
c := &Config{
ApiAddress: server.URL,
Username: "foo",
Expand All @@ -93,6 +94,31 @@ func TestTokenRefresh(t *testing.T) {
So(err, ShouldBeNil)

gomega.Consistently(token).Should(gomega.Equal("bearer foobar2"))
// gomega.Eventually(client.GetToken(), "3s").Should(gomega.Equal("bearer foobar3"))
gomega.Eventually(func() string { token, _ := client.GetToken(); return token }, "2s").Should(gomega.Equal("bearer foobar3"))
})
}

func TestEndpointRefresh(t *testing.T) {
gomega.RegisterTestingT(t)
Convey("Test expiring endpoint", t, func() {
setup(MockRoute{"GET", "/v2/organizations", listOrgsPayload, "", 200, "", nil}, t)
fakeUAAServer = FakeUAAServer(0)

c := &Config{
ApiAddress: server.URL,
Username: "foo",
Password: "bar",
}

client, err := NewClient(c)
So(err, ShouldBeNil)

lastTokenSource := client.Config.TokenSource
for i := 1; i < 5; i++ {
_, err := client.GetToken()
So(err, ShouldBeNil)
So(client.Config.TokenSource, ShouldNotEqual, lastTokenSource)
lastTokenSource = client.Config.TokenSource
}
})
}

0 comments on commit 2eadb63

Please sign in to comment.