Skip to content

Commit

Permalink
acl: RPC endpoints for JWT auth (#15918)
Browse files Browse the repository at this point in the history
  • Loading branch information
pkazmierczak committed Mar 16, 2023
1 parent 9c050e2 commit 519c746
Show file tree
Hide file tree
Showing 22 changed files with 877 additions and 73 deletions.
1 change: 1 addition & 0 deletions .semgrep/rpc_endpoint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ rules:
- pattern-not: 'structs.ACLListAuthMethodsRPCMethod'
- pattern-not: 'structs.ACLOIDCAuthURLRPCMethod'
- pattern-not: 'structs.ACLOIDCCompleteAuthRPCMethod'
- pattern-not: 'structs.ACLLoginRPCMethod'
- pattern-not: '"CSIPlugin.Get"'
- pattern-not: '"CSIPlugin.List"'
- pattern-not: '"Status.Leader"'
Expand Down
2 changes: 1 addition & 1 deletion command/agent/acl_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -867,7 +867,7 @@ func (s *HTTPServer) ACLOIDCCompleteAuthRequest(resp http.ResponseWriter, req *h
return nil, CodedError(http.StatusBadRequest, err.Error())
}

var out structs.ACLOIDCCompleteAuthResponse
var out structs.ACLLoginResponse
if err := s.agent.RPC(structs.ACLOIDCCompleteAuthRPCMethod, &args, &out); err != nil {
return nil, err
}
Expand Down
16 changes: 8 additions & 8 deletions command/agent/acl_endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1122,7 +1122,7 @@ func TestHTTPServer_ACLAuthMethodListRequest(t *testing.T) {

// Upsert two auth-methods into state.
must.NoError(t, srv.server.State().UpsertACLAuthMethods(
10, []*structs.ACLAuthMethod{mock.ACLAuthMethod(), mock.ACLAuthMethod()}))
10, []*structs.ACLAuthMethod{mock.ACLOIDCAuthMethod(), mock.ACLOIDCAuthMethod()}))

// Build the HTTP request.
req, err := http.NewRequest(http.MethodGet, "/v1/acl/auth-methods", nil)
Expand Down Expand Up @@ -1198,7 +1198,7 @@ func TestHTTPServer_ACLAuthMethodRequest(t *testing.T) {
testFn: func(srv *TestAgent) {

// Create a mock auth-method to use in the request body.
mockACLAuthMethod := mock.ACLAuthMethod()
mockACLAuthMethod := mock.ACLOIDCAuthMethod()

// Build the HTTP request.
req, err := http.NewRequest(http.MethodPut, "/v1/acl/auth-method", encodeReq(mockACLAuthMethod))
Expand Down Expand Up @@ -1269,7 +1269,7 @@ func TestHTTPServer_ACLAuthMethodSpecificRequest(t *testing.T) {
testFn: func(srv *TestAgent) {

// Create a mock auth-method and put directly into state.
mockACLAuthMethod := mock.ACLAuthMethod()
mockACLAuthMethod := mock.ACLOIDCAuthMethod()
must.NoError(t, srv.server.State().UpsertACLAuthMethods(
20, []*structs.ACLAuthMethod{mockACLAuthMethod}))

Expand All @@ -1294,7 +1294,7 @@ func TestHTTPServer_ACLAuthMethodSpecificRequest(t *testing.T) {
testFn: func(srv *TestAgent) {

// Create a mock auth-method and put directly into state.
mockACLAuthMethod := mock.ACLAuthMethod()
mockACLAuthMethod := mock.ACLOIDCAuthMethod()
must.NoError(t, srv.server.State().UpsertACLAuthMethods(
20, []*structs.ACLAuthMethod{mockACLAuthMethod}))

Expand Down Expand Up @@ -1499,7 +1499,7 @@ func TestHTTPServer_ACLBindingRuleRequest(t *testing.T) {

// Upsert the auth method that the binding rule will associate
// with.
mockACLAuthMethod := mock.ACLAuthMethod()
mockACLAuthMethod := mock.ACLOIDCAuthMethod()
must.NoError(t, srv.server.State().UpsertACLAuthMethods(
10, []*structs.ACLAuthMethod{mockACLAuthMethod}))

Expand Down Expand Up @@ -1607,7 +1607,7 @@ func TestHTTPServer_ACLBindingRuleSpecificRequest(t *testing.T) {

// Upsert the auth method that the binding rule will associate
// with.
mockACLAuthMethod := mock.ACLAuthMethod()
mockACLAuthMethod := mock.ACLOIDCAuthMethod()
must.NoError(t, srv.server.State().UpsertACLAuthMethods(
10, []*structs.ACLAuthMethod{mockACLAuthMethod}))

Expand Down Expand Up @@ -1716,7 +1716,7 @@ func TestHTTPServer_ACLOIDCAuthURLRequest(t *testing.T) {

// Generate and upsert an ACL auth method for use. Certain values must be
// taken from the cap OIDC provider just like real world use.
mockedAuthMethod := mock.ACLAuthMethod()
mockedAuthMethod := mock.ACLOIDCAuthMethod()
mockedAuthMethod.Config.AllowedRedirectURIs = []string{"http://127.0.0.1:4649/oidc/callback"}
mockedAuthMethod.Config.OIDCDiscoveryURL = oidcTestProvider.Addr()
mockedAuthMethod.Config.SigningAlgs = []string{"ES256"}
Expand Down Expand Up @@ -1799,7 +1799,7 @@ func TestHTTPServer_ACLOIDCCompleteAuthRequest(t *testing.T) {

// Generate and upsert an ACL auth method for use. Certain values must be
// taken from the cap OIDC provider just like real world use.
mockedAuthMethod := mock.ACLAuthMethod()
mockedAuthMethod := mock.ACLOIDCAuthMethod()
mockedAuthMethod.Config.BoundAudiences = []string{"mock"}
mockedAuthMethod.Config.AllowedRedirectURIs = []string{"http://127.0.0.1:4649/oidc/callback"}
mockedAuthMethod.Config.OIDCDiscoveryURL = oidcTestProvider.Addr()
Expand Down
3 changes: 1 addition & 2 deletions lib/auth/oidc/binder.go → lib/auth/binder.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package oidc
package auth

import (
"fmt"
Expand All @@ -8,7 +8,6 @@ import (
"github.com/hashicorp/go-memdb"
"github.com/hashicorp/hil"
"github.com/hashicorp/hil/ast"

"github.com/hashicorp/nomad/nomad/structs"
)

Expand Down
4 changes: 2 additions & 2 deletions lib/auth/oidc/binder_test.go → lib/auth/binder_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package oidc
package auth

import (
"testing"
Expand All @@ -19,7 +19,7 @@ func TestBinder_Bind(t *testing.T) {
testBind := NewBinder(testStore)

// create an authMethod method and insert into the state store
authMethod := mock.ACLAuthMethod()
authMethod := mock.ACLOIDCAuthMethod()
must.NoError(t, testStore.UpsertACLAuthMethods(0, []*structs.ACLAuthMethod{authMethod}))

// create some roles and insert into the state store
Expand Down
2 changes: 1 addition & 1 deletion lib/auth/oidc/claims.go → lib/auth/claims.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package oidc
package auth

import (
"encoding/json"
Expand Down
2 changes: 1 addition & 1 deletion lib/auth/oidc/claims_test.go → lib/auth/claims_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package oidc
package auth

import (
"testing"
Expand Down
2 changes: 1 addition & 1 deletion lib/auth/oidc/identity.go → lib/auth/identity.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package oidc
package auth

import (
"github.com/hashicorp/nomad/nomad/structs"
Expand Down
5 changes: 3 additions & 2 deletions lib/auth/oidc/identity_test.go → lib/auth/identity_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package oidc
package auth

import (
"github.com/shoenig/test/must"
"testing"

"github.com/shoenig/test/must"

"github.com/hashicorp/nomad/ci"
"github.com/hashicorp/nomad/nomad/structs"
)
Expand Down
125 changes: 125 additions & 0 deletions lib/auth/jwt/validator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
package jwt

import (
"context"
"crypto"
"fmt"
"time"

"github.com/armon/go-metrics"
"github.com/hashicorp/cap/jwt"
"golang.org/x/exp/slices"

"github.com/hashicorp/nomad/helper"
"github.com/hashicorp/nomad/nomad/structs"
)

// Validate performs token signature verification and JWT header validation,
// and returns a list of claims or an error in case any validation or signature
// verification fails.
func Validate(ctx context.Context, token string, methodConf *structs.ACLAuthMethodConfig) (map[string]any, error) {
var (
keySet jwt.KeySet
err error
)

// JWT validation can happen in 3 ways:
// - via embedded public keys, locally
// - via JWKS
// - or via OIDC provider
if len(methodConf.JWTValidationPubKeys) != 0 {
keySet, err = usingStaticKeys(methodConf.JWTValidationPubKeys)
if err != nil {
return nil, err
}
} else if methodConf.JWKSURL != "" {
keySet, err = usingJWKS(ctx, methodConf.JWKSURL, methodConf.JWKSCACert)
if err != nil {
return nil, err
}
} else if methodConf.OIDCDiscoveryURL != "" {
keySet, err = usingOIDC(ctx, methodConf.OIDCDiscoveryURL, methodConf.DiscoveryCaPem)
if err != nil {
return nil, err
}
}

// SigningAlgs field is a string, we need to convert it to a type the go-jwt
// accepts in order to validate.
toAlgFn := func(m string) jwt.Alg { return jwt.Alg(m) }
algorithms := helper.ConvertSlice(methodConf.SigningAlgs, toAlgFn)

expected := jwt.Expected{
Audiences: methodConf.BoundAudiences,
SigningAlgorithms: algorithms,
NotBeforeLeeway: methodConf.NotBeforeLeeway,
ExpirationLeeway: methodConf.ExpirationLeeway,
ClockSkewLeeway: methodConf.ClockSkewLeeway,
}

validator, err := jwt.NewValidator(keySet)
if err != nil {
return nil, err
}

claims, err := validator.Validate(ctx, token, expected)
if err != nil {
return nil, fmt.Errorf("unable to verify signature of JWT token: %v", err)
}

// validate issuer manually, because we allow users to specify an array
if len(methodConf.BoundIssuer) > 0 {
if _, ok := claims["iss"]; !ok {
return nil, fmt.Errorf(
"auth method specifies BoundIssuers but the provided token does not contain issuer information",
)
}
if iss, ok := claims["iss"].(string); !ok {
return nil, fmt.Errorf("unable to read iss property of provided token")
} else if !slices.Contains(methodConf.BoundIssuer, iss) {
return nil, fmt.Errorf("invalid JWT issuer: %v", claims["iss"])
}
}

return claims, nil
}

func usingStaticKeys(keys []string) (jwt.KeySet, error) {
var parsedKeys []crypto.PublicKey
for _, v := range keys {
key, err := jwt.ParsePublicKeyPEM([]byte(v))
parsedKeys = append(parsedKeys, key)
if err != nil {
return nil, fmt.Errorf("unable to parse public key for JWT auth: %v", err)
}
}
return jwt.NewStaticKeySet(parsedKeys)
}

func usingJWKS(ctx context.Context, jwksurl, jwkscapem string) (jwt.KeySet, error) {
// Measure the JWKS endpoint performance.
defer metrics.MeasureSince([]string{"nomad", "acl", "jwt", "jwks"}, time.Now())

keySet, err := jwt.NewJSONWebKeySet(ctx, jwksurl, jwkscapem)
if err != nil {
return nil, fmt.Errorf("unable to get validation keys from JWKS: %v", err)
}
return keySet, nil
}

func usingOIDC(ctx context.Context, oidcurl string, oidccapem []string) (jwt.KeySet, error) {
// Measure the OIDC endpoint performance.
defer metrics.MeasureSince([]string{"nomad", "acl", "jwt", "oidc_jwt"}, time.Now())

// TODO why do we have DiscoverCaPem as an array but JWKSCaPem as a single string?
pem := ""
if len(oidccapem) > 0 {
pem = oidccapem[0]
}

keySet, err := jwt.NewOIDCDiscoveryKeySet(ctx, oidcurl, pem)
if err != nil {
return nil, fmt.Errorf("unable to get validation keys from OIDC provider: %v", err)
}
return keySet, nil
}
Loading

0 comments on commit 519c746

Please sign in to comment.