diff --git a/cf_test.go b/cf_test.go index 5234a491..61a9758d 100644 --- a/cf_test.go +++ b/cf_test.go @@ -68,7 +68,7 @@ func testPostQuery(req *http.Request, postFormBody *string, t *testing.T) { func setupMultiple(mockEndpoints []MockRoute, t *testing.T) { mux = http.NewServeMux() server = httptest.NewServer(mux) - fakeUAAServer = FakeUAAServer() + fakeUAAServer = FakeUAAServer(3) m := martini.New() m.Use(render.Renderer()) r := martini.NewRouter() @@ -120,7 +120,7 @@ func setupMultiple(mockEndpoints []MockRoute, t *testing.T) { mux.Handle("/", m) } -func FakeUAAServer() *httptest.Server { +func FakeUAAServer(expiresIn int) *httptest.Server { mux := http.NewServeMux() server := httptest.NewServer(mux) m := martini.New() @@ -132,7 +132,7 @@ func FakeUAAServer() *httptest.Server { "token_type": "bearer", "access_token": "foobar" + strconv.Itoa(count), "refresh_token": "barfoo", - "expires_in": 3, + "expires_in": expiresIn, }) count = count + 1 }) diff --git a/client.go b/client.go index b1dcbadf..ce8f3d2b 100644 --- a/client.go +++ b/client.go @@ -12,6 +12,7 @@ import ( "golang.org/x/net/context" "golang.org/x/oauth2" "golang.org/x/oauth2/clientcredentials" + "time" ) //Client used to communicate with Cloud Foundry @@ -29,16 +30,17 @@ type Endpoint struct { //Config is used to configure the creation of a client type Config struct { - ApiAddress string `json:"api_url"` - Username string `json:"user"` - Password string `json:"password"` - ClientID string `json:"client_id"` - ClientSecret string `json:"client_secret"` - SkipSslValidation bool `json:"skip_ssl_validation"` - HttpClient *http.Client - Token string `json:"auth_token"` - TokenSource oauth2.TokenSource - UserAgent string `json:"user_agent"` + ApiAddress string `json:"api_url"` + Username string `json:"user"` + Password string `json:"password"` + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` + SkipSslValidation bool `json:"skip_ssl_validation"` + HttpClient *http.Client + Token string `json:"auth_token"` + TokenSource oauth2.TokenSource + tokenSourceDeadline *time.Time + UserAgent string `json:"user_agent"` } // request is used to help build up a request @@ -55,13 +57,13 @@ type request struct { //Need to be remove in close future func DefaultConfig() *Config { return &Config{ - ApiAddress: "http://api.bosh-lite.com", - Username: "admin", - Password: "admin", - Token: "", - SkipSslValidation: false, - HttpClient: http.DefaultClient, - UserAgent: "Go-CF-client/1.1", + ApiAddress: "http://api.bosh-lite.com", + Username: "admin", + Password: "admin", + Token: "", + SkipSslValidation: false, + HttpClient: http.DefaultClient, + UserAgent: "Go-CF-client/1.1", } } @@ -125,37 +127,14 @@ func NewClient(config *Config) (client *Client, err error) { tp.TLSClientConfig.InsecureSkipVerify = config.SkipSslValidation } - // we want to keep the Timeout value from config.HttpClient - timeout := config.HttpClient.Timeout - - ctx := context.Background() - ctx = context.WithValue(ctx, oauth2.HTTPClient, config.HttpClient) - - endpoint, err := getInfo(config.ApiAddress, oauth2.NewClient(ctx, nil)) - - if err != nil { - return nil, errors.Wrap(err, "Could not get api /v2/info") + client = &Client{ + Config: *config, } - switch { - case config.Token != "": - config = getUserTokenAuth(ctx, config, endpoint) - case config.ClientID != "": - config = getClientAuth(ctx, config, endpoint) - default: - config, err = getUserAuth(ctx, config, endpoint) - if err != nil { - return nil, err - } - } - // make sure original Timeout value will be used - if config.HttpClient.Timeout != timeout { - config.HttpClient.Timeout = timeout - } - client = &Client{ - Config: *config, - Endpoint: *endpoint, + if err := client.refreshEndpoint(); err != nil { + return nil, err } + return client, nil } @@ -168,7 +147,7 @@ func shallowDefaultTransport() *http.Transport { } } -func getUserAuth(ctx context.Context, config *Config, endpoint *Endpoint) (*Config, error) { +func getUserAuth(ctx context.Context, config Config, endpoint *Endpoint) (Config, error) { authConfig := &oauth2.Config{ ClientID: "cf", Scopes: []string{""}, @@ -179,18 +158,18 @@ func getUserAuth(ctx context.Context, config *Config, endpoint *Endpoint) (*Conf } token, err := authConfig.PasswordCredentialsToken(ctx, config.Username, config.Password) - if err != nil { - return nil, errors.Wrap(err, "Error getting token") + return config, errors.Wrap(err, "Error getting token") } + config.tokenSourceDeadline = &token.Expiry config.TokenSource = authConfig.TokenSource(ctx, token) config.HttpClient = oauth2.NewClient(ctx, config.TokenSource) return config, err } -func getClientAuth(ctx context.Context, config *Config, endpoint *Endpoint) *Config { +func getClientAuth(ctx context.Context, config Config, endpoint *Endpoint) Config { authConfig := &clientcredentials.Config{ ClientID: config.ClientID, ClientSecret: config.ClientSecret, @@ -203,7 +182,7 @@ func getClientAuth(ctx context.Context, config *Config, endpoint *Endpoint) *Con } // getUserTokenAuth initializes client credentials from existing bearer token. -func getUserTokenAuth(ctx context.Context, config *Config, endpoint *Endpoint) *Config { +func getUserTokenAuth(ctx context.Context, config Config, endpoint *Endpoint) Config { authConfig := &oauth2.Config{ ClientID: "cf", Scopes: []string{""}, @@ -293,6 +272,38 @@ func (c *Client) DoRequest(r *request) (*http.Response, error) { return resp, nil } +func (c *Client) refreshEndpoint() error { + // we want to keep the Timeout value from config.HttpClient + timeout := c.Config.HttpClient.Timeout + + ctx := context.Background() + ctx = context.WithValue(ctx, oauth2.HTTPClient, c.Config.HttpClient) + + endpoint, err := getInfo(c.Config.ApiAddress, oauth2.NewClient(ctx, nil)) + + if err != nil { + return errors.Wrap(err, "Could not get api /v2/info") + } + + switch { + case c.Config.Token != "": + c.Config = getUserTokenAuth(ctx, c.Config, endpoint) + case c.Config.ClientID != "": + c.Config = getClientAuth(ctx, c.Config, endpoint) + default: + c.Config, err = getUserAuth(ctx, c.Config, endpoint) + if err != nil { + return err + } + } + // make sure original Timeout value will be used + if c.Config.HttpClient.Timeout != timeout { + c.Config.HttpClient.Timeout = timeout + } + + return nil +} + // toHTTP converts the request to an HTTP request func (r *request) toHTTP() (*http.Request, error) { @@ -327,6 +338,12 @@ func encodeBody(obj interface{}) (io.Reader, error) { } func (c *Client) GetToken() (string, error) { + if c.Config.tokenSourceDeadline != nil && c.Config.tokenSourceDeadline.Before(time.Now()) { + if err := c.refreshEndpoint(); err != nil { + return "", err + } + } + token, err := c.Config.TokenSource.Token() if err != nil { return "", errors.Wrap(err, "Error getting bearer token") diff --git a/client_test.go b/client_test.go index 688e51d1..682114ce 100644 --- a/client_test.go +++ b/client_test.go @@ -81,6 +81,7 @@ func TestTokenRefresh(t *testing.T) { gomega.RegisterTestingT(t) Convey("Test making request", t, func() { setup(MockRoute{"GET", "/v2/organizations", listOrgsPayload, "", 200, "", nil}, t) + fakeUAAServer = FakeUAAServer(1) c := &Config{ ApiAddress: server.URL, Username: "foo", @@ -93,6 +94,31 @@ func TestTokenRefresh(t *testing.T) { So(err, ShouldBeNil) gomega.Consistently(token).Should(gomega.Equal("bearer foobar2")) - // gomega.Eventually(client.GetToken(), "3s").Should(gomega.Equal("bearer foobar3")) + gomega.Eventually(func() string { token, _ := client.GetToken(); return token }, "2s").Should(gomega.Equal("bearer foobar3")) + }) +} + +func TestEndpointRefresh(t *testing.T) { + gomega.RegisterTestingT(t) + Convey("Test expiring endpoint", t, func() { + setup(MockRoute{"GET", "/v2/organizations", listOrgsPayload, "", 200, "", nil}, t) + fakeUAAServer = FakeUAAServer(0) + + c := &Config{ + ApiAddress: server.URL, + Username: "foo", + Password: "bar", + } + + client, err := NewClient(c) + So(err, ShouldBeNil) + + lastTokenSource := client.Config.TokenSource + for i := 1; i < 5; i++ { + _, err := client.GetToken() + So(err, ShouldBeNil) + So(client.Config.TokenSource, ShouldNotEqual, lastTokenSource) + lastTokenSource = client.Config.TokenSource + } }) }