Skip to content

Commit

Permalink
Modularized api token management in GRAPPA drivers (#1574)
Browse files Browse the repository at this point in the history
  • Loading branch information
jimil749 authored Mar 23, 2021
1 parent c5ec249 commit a6439ff
Show file tree
Hide file tree
Showing 4 changed files with 204 additions and 264 deletions.
5 changes: 5 additions & 0 deletions changelog/unreleased/modularize-api-token-management.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Change: Modularize api token management in GRAPPA drivers

This PR moves the duplicated api token management methods into a seperate utils package

https://github.com/cs3org/reva/issues/1562
143 changes: 11 additions & 132 deletions pkg/cbox/group/rest/rest.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,18 @@ package rest

import (
"context"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"regexp"
"strings"
"sync"
"time"

grouppb "github.com/cs3org/go-cs3apis/cs3/identity/group/v1beta1"
userpb "github.com/cs3org/go-cs3apis/cs3/identity/user/v1beta1"
"github.com/cs3org/reva/pkg/appctx"
utils "github.com/cs3org/reva/pkg/cbox/utils"
"github.com/cs3org/reva/pkg/group"
"github.com/cs3org/reva/pkg/group/manager/registry"
"github.com/cs3org/reva/pkg/rhttp"
"github.com/gomodule/redigo/redis"
"github.com/mitchellh/mapstructure"
)
Expand All @@ -50,17 +45,9 @@ var (
)

type manager struct {
conf *config
redisPool *redis.Pool
oidcToken OIDCToken
client *http.Client
}

// OIDCToken stores the OIDC token used to authenticate requests to the REST API service
type OIDCToken struct {
sync.Mutex // concurrent access to apiToken and tokenExpirationTime
apiToken string
tokenExpirationTime time.Time
conf *config
redisPool *redis.Pool
apiTokenManager *utils.APITokenManager
}

type config struct {
Expand Down Expand Up @@ -125,126 +112,18 @@ func New(m map[string]interface{}) (group.Manager, error) {
c.init()

redisPool := initRedisPool(c.RedisAddress, c.RedisUsername, c.RedisPassword)
apiTokenManager := utils.InitAPITokenManager(c.TargetAPI, c.OIDCTokenEndpoint, c.ClientID, c.ClientSecret)
return &manager{
conf: c,
redisPool: redisPool,
client: rhttp.GetHTTPClient(
rhttp.Timeout(10*time.Second),
rhttp.Insecure(true),
),
conf: c,
redisPool: redisPool,
apiTokenManager: apiTokenManager,
}, nil
}

func (m *manager) renewAPIToken(ctx context.Context, forceRenewal bool) error {
// Received tokens have an expiration time of 20 minutes.
// Take a couple of seconds as buffer time for the API call to complete
if forceRenewal || m.oidcToken.tokenExpirationTime.Before(time.Now().Add(time.Second*time.Duration(2))) {
token, expiration, err := m.getAPIToken(ctx)
if err != nil {
return err
}

m.oidcToken.Lock()
defer m.oidcToken.Unlock()

m.oidcToken.apiToken = token
m.oidcToken.tokenExpirationTime = expiration
}
return nil
}

func (m *manager) getAPIToken(ctx context.Context) (string, time.Time, error) {

params := url.Values{
"grant_type": {"client_credentials"},
"audience": {m.conf.TargetAPI},
}

httpReq, err := http.NewRequest("POST", m.conf.OIDCTokenEndpoint, strings.NewReader(params.Encode()))
if err != nil {
return "", time.Time{}, err
}
httpReq.SetBasicAuth(m.conf.ClientID, m.conf.ClientSecret)
httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")

httpRes, err := m.client.Do(httpReq)
if err != nil {
return "", time.Time{}, err
}
defer httpRes.Body.Close()

body, err := ioutil.ReadAll(httpRes.Body)
if err != nil {
return "", time.Time{}, err
}
if httpRes.StatusCode < 200 || httpRes.StatusCode > 299 {
return "", time.Time{}, errors.New("rest: get token endpoint returned " + httpRes.Status)
}

var result map[string]interface{}
err = json.Unmarshal(body, &result)
if err != nil {
return "", time.Time{}, err
}

expirationSecs := result["expires_in"].(float64)
expirationTime := time.Now().Add(time.Second * time.Duration(expirationSecs))
return result["access_token"].(string), expirationTime, nil
}

func (m *manager) sendAPIRequest(ctx context.Context, url string, forceRenewal bool) ([]interface{}, error) {
err := m.renewAPIToken(ctx, forceRenewal)
if err != nil {
return nil, err
}

httpReq, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, err
}

// We don't need to take the lock when reading apiToken, because if we reach here,
// the token is valid at least for a couple of seconds. Even if another request modifies
// the token and expiration time while this request is in progress, the current token will still be valid.
httpReq.Header.Set("Authorization", "Bearer "+m.oidcToken.apiToken)

httpRes, err := m.client.Do(httpReq)
if err != nil {
return nil, err
}
defer httpRes.Body.Close()

if httpRes.StatusCode == http.StatusUnauthorized {
// The token is no longer valid, try renewing it
return m.sendAPIRequest(ctx, url, true)
}
if httpRes.StatusCode < 200 || httpRes.StatusCode > 299 {
return nil, errors.New("rest: API request returned " + httpRes.Status)
}

body, err := ioutil.ReadAll(httpRes.Body)
if err != nil {
return nil, err
}

var result map[string]interface{}
err = json.Unmarshal(body, &result)
if err != nil {
return nil, err
}

responseData, ok := result["data"].([]interface{})
if !ok {
return nil, errors.New("rest: error in type assertion")
}

return responseData, nil
}

func (m *manager) getGroupByParam(ctx context.Context, param, val string) (map[string]interface{}, error) {
url := fmt.Sprintf("%s/Group?filter=%s:%s&field=groupIdentifier&field=displayName&field=gid",
m.conf.APIBaseURL, param, val)
responseData, err := m.sendAPIRequest(ctx, url, false)
responseData, err := m.apiTokenManager.SendAPIGetRequest(ctx, url, false)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -369,7 +248,7 @@ func (m *manager) GetGroupByClaim(ctx context.Context, claim, value string) (*gr

func (m *manager) findGroupsByFilter(ctx context.Context, url string, groups map[string]*grouppb.Group) error {

groupData, err := m.sendAPIRequest(ctx, url, false)
groupData, err := m.apiTokenManager.SendAPIGetRequest(ctx, url, false)
if err != nil {
return err
}
Expand Down Expand Up @@ -440,7 +319,7 @@ func (m *manager) GetMembers(ctx context.Context, gid *grouppb.GroupId) ([]*user
return nil, err
}
url := fmt.Sprintf("%s/Group/%s/memberidentities/precomputed", m.conf.APIBaseURL, internalID)
userData, err := m.sendAPIRequest(ctx, url, false)
userData, err := m.apiTokenManager.SendAPIGetRequest(ctx, url, false)
if err != nil {
return nil, err
}
Expand Down
143 changes: 11 additions & 132 deletions pkg/cbox/user/rest/rest.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,16 @@ package rest

import (
"context"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"regexp"
"strings"
"sync"
"time"

userpb "github.com/cs3org/go-cs3apis/cs3/identity/user/v1beta1"
types "github.com/cs3org/go-cs3apis/cs3/types/v1beta1"
"github.com/cs3org/reva/pkg/appctx"
"github.com/cs3org/reva/pkg/rhttp"
utils "github.com/cs3org/reva/pkg/cbox/utils"
"github.com/cs3org/reva/pkg/user"
"github.com/cs3org/reva/pkg/user/manager/registry"
"github.com/gomodule/redigo/redis"
Expand All @@ -51,17 +46,9 @@ var (
)

type manager struct {
conf *config
redisPool *redis.Pool
oidcToken OIDCToken
client *http.Client
}

// OIDCToken stores the OIDC token used to authenticate requests to the REST API service
type OIDCToken struct {
sync.Mutex // concurrent access to apiToken and tokenExpirationTime
apiToken string
tokenExpirationTime time.Time
conf *config
redisPool *redis.Pool
apiTokenManager *utils.APITokenManager
}

type config struct {
Expand Down Expand Up @@ -126,126 +113,18 @@ func New(m map[string]interface{}) (user.Manager, error) {
c.init()

redisPool := initRedisPool(c.RedisAddress, c.RedisUsername, c.RedisPassword)
apiTokenManager := utils.InitAPITokenManager(c.TargetAPI, c.OIDCTokenEndpoint, c.ClientID, c.ClientSecret)
return &manager{
conf: c,
redisPool: redisPool,
client: rhttp.GetHTTPClient(
rhttp.Timeout(10*time.Second),
rhttp.Insecure(true),
),
conf: c,
redisPool: redisPool,
apiTokenManager: apiTokenManager,
}, nil
}

func (m *manager) renewAPIToken(ctx context.Context, forceRenewal bool) error {
// Received tokens have an expiration time of 20 minutes.
// Take a couple of seconds as buffer time for the API call to complete
if forceRenewal || m.oidcToken.tokenExpirationTime.Before(time.Now().Add(time.Second*time.Duration(2))) {
token, expiration, err := m.getAPIToken(ctx)
if err != nil {
return err
}

m.oidcToken.Lock()
defer m.oidcToken.Unlock()

m.oidcToken.apiToken = token
m.oidcToken.tokenExpirationTime = expiration
}
return nil
}

func (m *manager) getAPIToken(ctx context.Context) (string, time.Time, error) {

params := url.Values{
"grant_type": {"client_credentials"},
"audience": {m.conf.TargetAPI},
}

httpReq, err := http.NewRequest("POST", m.conf.OIDCTokenEndpoint, strings.NewReader(params.Encode()))
if err != nil {
return "", time.Time{}, err
}
httpReq.SetBasicAuth(m.conf.ClientID, m.conf.ClientSecret)
httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")

httpRes, err := m.client.Do(httpReq)
if err != nil {
return "", time.Time{}, err
}
defer httpRes.Body.Close()
if httpRes.StatusCode < 200 || httpRes.StatusCode > 299 {
return "", time.Time{}, errors.New("rest: get token endpoint returned " + httpRes.Status)
}

body, err := ioutil.ReadAll(httpRes.Body)
if err != nil {
return "", time.Time{}, err
}

var result map[string]interface{}
err = json.Unmarshal(body, &result)
if err != nil {
return "", time.Time{}, err
}

expirationSecs := result["expires_in"].(float64)
expirationTime := time.Now().Add(time.Second * time.Duration(expirationSecs))
return result["access_token"].(string), expirationTime, nil
}

func (m *manager) sendAPIRequest(ctx context.Context, url string, forceRenewal bool) ([]interface{}, error) {
err := m.renewAPIToken(ctx, forceRenewal)
if err != nil {
return nil, err
}

httpReq, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, err
}

// We don't need to take the lock when reading apiToken, because if we reach here,
// the token is valid at least for a couple of seconds. Even if another request modifies
// the token and expiration time while this request is in progress, the current token will still be valid.
httpReq.Header.Set("Authorization", "Bearer "+m.oidcToken.apiToken)

httpRes, err := m.client.Do(httpReq)
if err != nil {
return nil, err
}
defer httpRes.Body.Close()

if httpRes.StatusCode == http.StatusUnauthorized {
// The token is no longer valid, try renewing it
return m.sendAPIRequest(ctx, url, true)
}
if httpRes.StatusCode < 200 || httpRes.StatusCode > 299 {
return nil, errors.New("rest: API request returned " + httpRes.Status)
}

body, err := ioutil.ReadAll(httpRes.Body)
if err != nil {
return nil, err
}

var result map[string]interface{}
err = json.Unmarshal(body, &result)
if err != nil {
return nil, err
}

responseData, ok := result["data"].([]interface{})
if !ok {
return nil, errors.New("rest: error in type assertion")
}

return responseData, nil
}

func (m *manager) getUserByParam(ctx context.Context, param, val string) (map[string]interface{}, error) {
url := fmt.Sprintf("%s/Identity?filter=%s:%s&field=upn&field=primaryAccountEmail&field=displayName&field=uid&field=gid&field=type",
m.conf.APIBaseURL, param, val)
responseData, err := m.sendAPIRequest(ctx, url, false)
responseData, err := m.apiTokenManager.SendAPIGetRequest(ctx, url, false)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -381,7 +260,7 @@ func (m *manager) GetUserByClaim(ctx context.Context, claim, value string) (*use

func (m *manager) findUsersByFilter(ctx context.Context, url string, users map[string]*userpb.User) error {

userData, err := m.sendAPIRequest(ctx, url, false)
userData, err := m.apiTokenManager.SendAPIGetRequest(ctx, url, false)
if err != nil {
return err
}
Expand Down Expand Up @@ -466,7 +345,7 @@ func (m *manager) GetUserGroups(ctx context.Context, uid *userpb.UserId) ([]stri
return nil, err
}
url := fmt.Sprintf("%s/Identity/%s/groups", m.conf.APIBaseURL, internalID)
groupData, err := m.sendAPIRequest(ctx, url, false)
groupData, err := m.apiTokenManager.SendAPIGetRequest(ctx, url, false)
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit a6439ff

Please sign in to comment.