diff --git a/master/internal/api_token.go b/master/internal/api_token.go index b468b0a86c9..7b18e5ce097 100644 --- a/master/internal/api_token.go +++ b/master/internal/api_token.go @@ -19,19 +19,12 @@ import ( "github.com/determined-ai/determined/proto/pkg/apiv1" ) -var errAccessTokenRequiresEE = status.Error( - codes.FailedPrecondition, - "users cannot log in with an access token without a valid Enterprise Edition license set up.", -) - // PostAccessToken takes user id and optional lifespan, description and creates an // access token for the given user. func (a *apiServer) PostAccessToken( ctx context.Context, req *apiv1.PostAccessTokenRequest, ) (*apiv1.PostAccessTokenResponse, error) { - if !license.IsEE() { - return nil, errAccessTokenRequiresEE - } + license.RequireLicense("access tokens") curUser, _, err := grpcutil.GetUser(ctx) if err != nil { @@ -85,9 +78,7 @@ func (a *apiServer) PostAccessToken( func (a *apiServer) GetAccessTokens( ctx context.Context, req *apiv1.GetAccessTokensRequest, ) (*apiv1.GetAccessTokensResponse, error) { - if !license.IsEE() { - return nil, errAccessTokenRequiresEE - } + license.RequireLicense("access tokens") curUser, _, err := grpcutil.GetUser(ctx) if err != nil { @@ -192,9 +183,7 @@ func (a *apiServer) GetAccessTokens( func (a *apiServer) PatchAccessToken( ctx context.Context, req *apiv1.PatchAccessTokenRequest, ) (*apiv1.PatchAccessTokenResponse, error) { - if !license.IsEE() { - return nil, errAccessTokenRequiresEE - } + license.RequireLicense("access tokens") curUser, _, err := grpcutil.GetUser(ctx) if err != nil { diff --git a/master/internal/license/license.go b/master/internal/license/license.go index 0e3ca5492aa..9008447c5df 100644 --- a/master/internal/license/license.go +++ b/master/internal/license/license.go @@ -1,14 +1,5 @@ package license -import ( - "crypto/x509" - "encoding/base64" - "encoding/pem" - "fmt" - - "github.com/golang-jwt/jwt/v4" -) - const ( licenseRequiredMsg = "An enterprise license is required to use this feature" errCheckingLicense = "error when validating license" @@ -20,42 +11,8 @@ var licenseKey string // publicKey stores the public key used to verify licenses. Defaults to empty. var publicKey string -// decodedLicense contains the body of a decoded licenseKey. -type decodedLicense struct { - jwt.RegisteredClaims - - LicenseVersion string `json:"licenseVersion"` -} - -// RequireLicense panics if no licenseKey or an invalid licenseKey is used. -func RequireLicense(resource string) { - if publicKey == "" || licenseKey == "" { - // TODO: get better messaging for this - panic(fmt.Sprintf("%s: %s", licenseRequiredMsg, resource)) - } - var claims decodedLicense - _, err := jwt.ParseWithClaims(licenseKey, &claims, func(token *jwt.Token) (interface{}, error) { - pemData, err := base64.StdEncoding.DecodeString(publicKey) - if err != nil { - return nil, err - } - blk, _ := pem.Decode(pemData) - if blk == nil { - return nil, fmt.Errorf("error decoding pem") - } - key, err := x509.ParsePKIXPublicKey(blk.Bytes) - if err != nil { - return nil, fmt.Errorf("error parsing public key: %w", err) - } - return key, nil - }) - if err != nil { - panic(fmt.Sprintf("%s: %s", errCheckingLicense, err.Error())) - } - if claims.LicenseVersion != "1" { - panic("Specified licenseKey version is incompatible") - } -} +// RequireLicense is a no-op. +func RequireLicense(resource string) {} // IsEE returns true if a license is detected. func IsEE() bool {