Skip to content

Commit

Permalink
Accept reva token as a bearer authentication (cs3org#3315)
Browse files Browse the repository at this point in the history
  • Loading branch information
gmgigi96 authored and aduffeck committed Oct 9, 2023
1 parent e751971 commit 55babd0
Show file tree
Hide file tree
Showing 4 changed files with 215 additions and 67 deletions.
3 changes: 3 additions & 0 deletions changelog/unreleased/reva-token-bearer.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Enhancement: Accept reva token as a bearer authentication

https://github.com/cs3org/reva/pull/3315
194 changes: 127 additions & 67 deletions internal/http/interceptors/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"time"

"github.com/bluele/gcache"
authpb "github.com/cs3org/go-cs3apis/cs3/auth/provider/v1beta1"
gateway "github.com/cs3org/go-cs3apis/cs3/gateway/v1beta1"
userpb "github.com/cs3org/go-cs3apis/cs3/identity/user/v1beta1"
rpc "github.com/cs3org/go-cs3apis/cs3/rpc/v1beta1"
Expand Down Expand Up @@ -69,7 +70,7 @@ type config struct {
CredentialsByUserAgent map[string]string `mapstructure:"credentials_by_user_agent"`
CredentialChain []string `mapstructure:"credential_chain"`
CredentialStrategies map[string]map[string]interface{} `mapstructure:"credential_strategies"`
TokenStrategy string `mapstructure:"token_strategy"`
TokenStrategyChain []string `mapstructure:"token_strategy_chain"`
TokenStrategies map[string]map[string]interface{} `mapstructure:"token_strategies"`
TokenManager string `mapstructure:"token_manager"`
TokenManagers map[string]map[string]interface{} `mapstructure:"token_managers"`
Expand Down Expand Up @@ -97,8 +98,8 @@ func New(m map[string]interface{}, unprotected []string, tp trace.TracerProvider
conf.GatewaySvc = sharedconf.GetGatewaySVC(conf.GatewaySvc)

// set defaults
if conf.TokenStrategy == "" {
conf.TokenStrategy = "header"
if len(conf.TokenStrategyChain) == 0 {
conf.TokenStrategyChain = []string{"header"}
}

if conf.TokenWriter == "" {
Expand Down Expand Up @@ -139,19 +140,22 @@ func New(m map[string]interface{}, unprotected []string, tp trace.TracerProvider
credChain[key] = credStrategy
}

g, ok := tokenregistry.NewTokenFuncs[conf.TokenStrategy]
if !ok {
return nil, fmt.Errorf("token strategy not found: %s", conf.TokenStrategy)
}

tokenStrategy, err := g(conf.TokenStrategies[conf.TokenStrategy])
if err != nil {
return nil, err
tokenStrategyChain := make([]auth.TokenStrategy, 0, len(conf.TokenStrategyChain))
for _, strategy := range conf.TokenStrategyChain {
g, ok := tokenregistry.NewTokenFuncs[strategy]
if !ok {
return nil, fmt.Errorf("token strategy not found: %s", strategy)
}
tokenStrategy, err := g(conf.TokenStrategies[strategy])
if err != nil {
return nil, err
}
tokenStrategyChain = append(tokenStrategyChain, tokenStrategy)
}

h, ok := tokenmgr.NewFuncs[conf.TokenManager]
if !ok {
return nil, fmt.Errorf("token manager not found: %s", conf.TokenStrategy)
return nil, fmt.Errorf("token manager not found: %s", conf.TokenManager)
}

tokenManager, err := h(conf.TokenManagers[conf.TokenManager])
Expand Down Expand Up @@ -196,7 +200,7 @@ func New(m map[string]interface{}, unprotected []string, tp trace.TracerProvider
isUnprotectedEndpoint = true
}

ctx, err := authenticateUser(w, r, conf, tokenStrategy, tokenManager, tokenWriter, credChain, isUnprotectedEndpoint)
ctx, err := authenticateUser(w, r, conf, tokenStrategyChain, tokenManager, tokenWriter, credChain, isUnprotectedEndpoint)
if err != nil {
if !isUnprotectedEndpoint {
return
Expand All @@ -216,7 +220,7 @@ func New(m map[string]interface{}, unprotected []string, tp trace.TracerProvider
return chain, nil
}

func authenticateUser(w http.ResponseWriter, r *http.Request, conf *config, tokenStrategy auth.TokenStrategy, tokenManager token.Manager, tokenWriter auth.TokenWriter, credChain map[string]auth.CredentialStrategy, isUnprotectedEndpoint bool) (context.Context, error) {
func authenticateUser(w http.ResponseWriter, r *http.Request, conf *config, tokenStrategies []auth.TokenStrategy, tokenManager token.Manager, tokenWriter auth.TokenWriter, credChain map[string]auth.CredentialStrategy, isUnprotectedEndpoint bool) (context.Context, error) {
ctx := r.Context()
log := appctx.GetLogger(ctx)

Expand All @@ -229,71 +233,83 @@ func authenticateUser(w http.ResponseWriter, r *http.Request, conf *config, toke
return nil, err
}

tkn := tokenStrategy.GetToken(r)
if tkn == "" {
log.Warn().Msg("core access token not set")
// reva token or auth token can be passed using the same tecnique (for example bearer)
// before validating it against an auth provider, we can check directly if it's a reva
// token and if not try to use it for authenticating the user.
for _, tokenStrategy := range tokenStrategies {
token := tokenStrategy.GetToken(r)
if token != "" {
if user, tokenScope, ok := isTokenValid(r, tokenManager, token); ok {
if err := insertGroupsInUser(ctx, userGroupsCache, client, user); err != nil {
logError(isUnprotectedEndpoint, log, err, "got an error retrieving groups for user "+user.Username, http.StatusInternalServerError, w)
return nil, err
}
return ctxWithUserInfo(ctx, r, user, token, tokenScope), nil
}
}
}

userAgentCredKeys := getCredsForUserAgent(r.UserAgent(), conf.CredentialsByUserAgent, conf.CredentialChain)
log.Warn().Msg("core access token not set")

// obtain credentials (basic auth, bearer token, ...) based on user agent
var creds *auth.Credentials
for _, k := range userAgentCredKeys {
creds, err = credChain[k].GetCredentials(w, r)
if err != nil {
log.Debug().Err(err).Msg("error retrieving credentials")
}
userAgentCredKeys := getCredsForUserAgent(r.UserAgent(), conf.CredentialsByUserAgent, conf.CredentialChain)

if creds != nil {
log.Debug().Msgf("credentials obtained from credential strategy: type: %s, client_id: %s", creds.Type, creds.ClientID)
break
}
// obtain credentials (basic auth, bearer token, ...) based on user agent
var creds *auth.Credentials
for _, k := range userAgentCredKeys {
creds, err = credChain[k].GetCredentials(w, r)
if err != nil {
log.Debug().Err(err).Msg("error retrieving credentials")
}

// if no credentials are found, reply with authentication challenge depending on user agent
if creds == nil {
if !isUnprotectedEndpoint {
for _, key := range userAgentCredKeys {
if cred, ok := credChain[key]; ok {
cred.AddWWWAuthenticate(w, r, conf.Realm)
} else {
panic("auth credential strategy: " + key + "must have been loaded in init method")
}
}
w.WriteHeader(http.StatusUnauthorized)
}
return nil, errtypes.PermissionDenied("no credentials found")
if creds != nil {
log.Debug().Msgf("credentials obtained from credential strategy: type: %s, client_id: %s", creds.Type, creds.ClientID)
break
}
}

req := &gateway.AuthenticateRequest{
Type: creds.Type,
ClientId: creds.ClientID,
ClientSecret: creds.ClientSecret,
// if no credentials are found, reply with authentication challenge depending on user agent
if creds == nil {
if !isUnprotectedEndpoint {
for _, key := range userAgentCredKeys {
if cred, ok := credChain[key]; ok {
cred.AddWWWAuthenticate(w, r, conf.Realm)
} else {
panic("auth credential strategy: " + key + "must have been loaded in init method")
}
}
w.WriteHeader(http.StatusUnauthorized)
}
return nil, errtypes.PermissionDenied("no credentials found")
}

log.Debug().Msgf("AuthenticateRequest: type: %s, client_id: %s against %s", req.Type, req.ClientId, conf.GatewaySvc)
req := &gateway.AuthenticateRequest{
Type: creds.Type,
ClientId: creds.ClientID,
ClientSecret: creds.ClientSecret,
}

res, err := client.Authenticate(ctx, req)
if err != nil {
logError(isUnprotectedEndpoint, log, err, "error calling Authenticate", http.StatusUnauthorized, w)
return nil, err
}
log.Debug().Msgf("AuthenticateRequest: type: %s, client_id: %s against %s", req.Type, req.ClientId, conf.GatewaySvc)

if res.Status.Code != rpc.Code_CODE_OK {
err := status.NewErrorFromCode(res.Status.Code, "auth")
logError(isUnprotectedEndpoint, log, err, "error generating access token from credentials", http.StatusUnauthorized, w)
return nil, err
}
res, err := client.Authenticate(ctx, req)
if err != nil {
logError(isUnprotectedEndpoint, log, err, "error calling Authenticate", http.StatusUnauthorized, w)
return nil, err
}

log.Info().Msg("core access token generated")
// write token to response
tkn = res.Token
tokenWriter.WriteToken(tkn, w)
} else {
log.Debug().Msg("access token is already provided")
if res.Status.Code != rpc.Code_CODE_OK {
err := status.NewErrorFromCode(res.Status.Code, "auth")
logError(isUnprotectedEndpoint, log, err, "error generating access token from credentials", http.StatusUnauthorized, w)
return nil, err
}

log.Info().Msg("core access token generated") // write token to response

// write token to response
token := res.Token
tokenWriter.WriteToken(token, w)

// validate token
u, tokenScope, err := tokenManager.DismantleToken(r.Context(), tkn)
u, tokenScope, err := tokenManager.DismantleToken(r.Context(), token)
if err != nil {
logError(isUnprotectedEndpoint, log, err, "error dismantling token", http.StatusUnauthorized, w)
return nil, err
Expand Down Expand Up @@ -329,15 +345,59 @@ func authenticateUser(w http.ResponseWriter, r *http.Request, conf *config, toke

// store user and core access token in context.
ctx = ctxpkg.ContextSetUser(ctx, u)
ctx = ctxpkg.ContextSetToken(ctx, tkn)
ctx = metadata.AppendToOutgoingContext(ctx, ctxpkg.TokenHeader, tkn) // TODO(jfd): hardcoded metadata key. use PerRPCCredentials?
ctx = ctxpkg.ContextSetToken(ctx, token)
ctx = metadata.AppendToOutgoingContext(ctx, ctxpkg.TokenHeader, token) // TODO(jfd): hardcoded metadata key. use PerRPCCredentials?

ctx = metadata.AppendToOutgoingContext(ctx, ctxpkg.UserAgentHeader, r.UserAgent())

// store scopes in context
ctx = ctxpkg.ContextSetScopes(ctx, tokenScope)

return ctx, nil
return ctxWithUserInfo(ctx, r, u, token, tokenScope), nil
}

func ctxWithUserInfo(ctx context.Context, r *http.Request, user *userpb.User, token string, tokenScope map[string]*authpb.Scope) context.Context {
ctx = ctxpkg.ContextSetUser(ctx, user)
ctx = ctxpkg.ContextSetToken(ctx, token)
ctx = metadata.AppendToOutgoingContext(ctx, ctxpkg.TokenHeader, token)
ctx = metadata.AppendToOutgoingContext(ctx, ctxpkg.UserAgentHeader, r.UserAgent())
ctx = ctxpkg.ContextSetScopes(ctx, tokenScope)
return ctx
}

func insertGroupsInUser(ctx context.Context, userGroupsCache gcache.Cache, client gateway.GatewayAPIClient, user *userpb.User) error {
if sharedconf.SkipUserGroupsInToken() {
var groups []string
if groupsIf, err := userGroupsCache.Get(user.Id.OpaqueId); err == nil {
groups = groupsIf.([]string)
} else {
groupsRes, err := client.GetUserGroups(ctx, &userpb.GetUserGroupsRequest{UserId: user.Id})
if err != nil {
return err
}
groups = groupsRes.Groups
_ = userGroupsCache.SetWithExpire(user.Id.OpaqueId, groupsRes.Groups, 3600*time.Second)
}
user.Groups = groups
}
return nil
}

func isTokenValid(r *http.Request, tokenManager token.Manager, token string) (*userpb.User, map[string]*authpb.Scope, bool) {
ctx := r.Context()

u, tokenScope, err := tokenManager.DismantleToken(ctx, token)
if err != nil {
return nil, nil, false
}

// ensure access to the resource is allowed
ok, err := scope.VerifyScope(ctx, tokenScope, r.URL.Path)
if err != nil {
return nil, nil, false
}

return u, tokenScope, ok
}

func logError(isUnprotectedEndpoint bool, log *zerolog.Logger, err error, msg string, status int, w http.ResponseWriter) {
Expand Down
1 change: 1 addition & 0 deletions internal/http/interceptors/auth/token/loader/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package loader

import (
// Load core token strategies.
_ "github.com/cs3org/reva/v2/internal/http/interceptors/auth/token/strategy/bearer"
_ "github.com/cs3org/reva/v2/internal/http/interceptors/auth/token/strategy/header"
// Add your own here.
)
84 changes: 84 additions & 0 deletions internal/http/interceptors/auth/token/strategy/bearer/bearer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// Copyright 2018-2023 CERN
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// In applying this license, CERN does not waive the privileges and immunities
// granted to it by virtue of its status as an Intergovernmental Organization
// or submit itself to any jurisdiction.

package header

import (
"mime"
"net/http"
"strings"

"github.com/cs3org/reva/v2/internal/http/interceptors/auth/token/registry"
"github.com/cs3org/reva/v2/pkg/auth"
)

func init() {
registry.Register("bearer", New)
}

type b struct{}

// New returns a new auth strategy that checks for bearer auth.
func New(m map[string]interface{}) (auth.TokenStrategy, error) {
return b{}, nil
}

func (b) GetToken(r *http.Request) string {
// Authorization Request Header Field: https://www.rfc-editor.org/rfc/rfc6750#section-2.1
if tkn, ok := getFromAuthorizationHeader(r); ok {
return tkn
}

// Form-Encoded Body Parameter: https://www.rfc-editor.org/rfc/rfc6750#section-2.2
if tkn, ok := getFromBody(r); ok {
return tkn
}

// URI Query Parameter: https://www.rfc-editor.org/rfc/rfc6750#section-2.3
if tkn, ok := getFromQueryParam(r); ok {
return tkn
}

return ""
}

func getFromAuthorizationHeader(r *http.Request) (string, bool) {
auth := r.Header.Get("Authorization")
tkn := strings.TrimPrefix(auth, "Bearer ")
return tkn, tkn != ""
}

func getFromBody(r *http.Request) (string, bool) {
mediatype, _, err := mime.ParseMediaType(r.Header.Get("content-type"))
if err != nil {
return "", false
}
if mediatype != "application/x-www-form-urlencoded" {
return "", false
}
if err = r.ParseForm(); err != nil {
return "", false
}
tkn := r.Form.Get("access-token")
return tkn, tkn != ""
}

func getFromQueryParam(r *http.Request) (string, bool) {
tkn := r.URL.Query().Get("access_token")
return tkn, tkn != ""
}

0 comments on commit 55babd0

Please sign in to comment.