Skip to content

Commit

Permalink
Add operator license key check (#4925)
Browse files Browse the repository at this point in the history
Stops the operator on startup if the current operator license key is invalid.
  • Loading branch information
thbkrkr authored Oct 11, 2021
1 parent 995403b commit 138b0e1
Show file tree
Hide file tree
Showing 10 changed files with 200 additions and 98 deletions.
30 changes: 25 additions & 5 deletions cmd/manager/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
"github.com/elastic/cloud-on-k8s/pkg/controller/beat"
"github.com/elastic/cloud-on-k8s/pkg/controller/common/certificates"
"github.com/elastic/cloud-on-k8s/pkg/controller/common/container"
commonlicense "github.com/elastic/cloud-on-k8s/pkg/controller/common/license"
"github.com/elastic/cloud-on-k8s/pkg/controller/common/operator"
"github.com/elastic/cloud-on-k8s/pkg/controller/common/reconciler"
controllerscheme "github.com/elastic/cloud-on-k8s/pkg/controller/common/scheme"
Expand Down Expand Up @@ -580,12 +581,31 @@ func startOperator(ctx context.Context) error {
"build_hash", operatorInfo.BuildInfo.Hash, "build_date", operatorInfo.BuildInfo.Date,
"build_snapshot", operatorInfo.BuildInfo.Snapshot)

if err := mgr.Start(ctx); err != nil {
log.Error(err, "Failed to start the controller manager")
return err
}
exitOnErr := make(chan error)

return nil
// start the manager
go func() {
if err := mgr.Start(ctx); err != nil {
log.Error(err, "Failed to start the controller manager")
exitOnErr <- err
}
}()

// check operator license key
go func() {
mgr.GetCache().WaitForCacheSync(ctx)

lc := commonlicense.NewLicenseChecker(mgr.GetClient(), params.OperatorNamespace)
licenseType, err := lc.ValidOperatorLicenseKeyType()
if err != nil {
log.Error(err, "Failed to validate operator license key")
exitOnErr <- err
} else {
log.Info("Operator license key validated", "license_type", licenseType)
}
}()

return <-exitOnErr
}

// asyncTasks schedules some tasks to be started when this instance of the operator is elected
Expand Down
32 changes: 8 additions & 24 deletions pkg/controller/autoscaling/elasticsearch/controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func TestReconcile(t *testing.T) {
fields: fields{
EsClient: newFakeEsClient(t).withCapacity("frozen-tier"),
recorder: record.NewFakeRecorder(1000),
licenseChecker: &fakeLicenceChecker{},
licenseChecker: &license.MockLicenseChecker{EnterpriseEnabled: true},
},
args: args{
esManifest: "frozen-tier",
Expand All @@ -109,7 +109,7 @@ func TestReconcile(t *testing.T) {
fields: fields{
EsClient: newFakeEsClient(t).withCapacity("ml"),
recorder: record.NewFakeRecorder(1000),
licenseChecker: &fakeLicenceChecker{},
licenseChecker: &license.MockLicenseChecker{EnterpriseEnabled: true},
},
args: args{
esManifest: "ml",
Expand All @@ -124,7 +124,7 @@ func TestReconcile(t *testing.T) {
fields: fields{
EsClient: newFakeEsClient(t).withErrorOnDeleteAutoscalingAutoscalingPolicies(),
recorder: record.NewFakeRecorder(1000),
licenseChecker: &fakeLicenceChecker{},
licenseChecker: &license.MockLicenseChecker{EnterpriseEnabled: true},
},
args: args{
esManifest: "min-nodes-increased-by-user",
Expand All @@ -139,7 +139,7 @@ func TestReconcile(t *testing.T) {
fields: fields{
EsClient: newFakeEsClient(t).withCapacity("empty-autoscaling-api-response"),
recorder: record.NewFakeRecorder(1000),
licenseChecker: &fakeLicenceChecker{},
licenseChecker: &license.MockLicenseChecker{EnterpriseEnabled: true},
},
args: args{
esManifest: "empty-autoscaling-api-response",
Expand All @@ -152,7 +152,7 @@ func TestReconcile(t *testing.T) {
fields: fields{
EsClient: newFakeEsClient(t),
recorder: record.NewFakeRecorder(1000),
licenseChecker: &fakeLicenceChecker{},
licenseChecker: &license.MockLicenseChecker{EnterpriseEnabled: true},
},
args: args{
esManifest: "cluster-creation",
Expand All @@ -165,7 +165,7 @@ func TestReconcile(t *testing.T) {
fields: fields{
EsClient: newFakeEsClient(t).withCapacity("max-storage-reached"),
recorder: record.NewFakeRecorder(1000),
licenseChecker: &fakeLicenceChecker{},
licenseChecker: &license.MockLicenseChecker{EnterpriseEnabled: true},
},
args: args{
esManifest: "max-storage-reached",
Expand All @@ -182,7 +182,7 @@ func TestReconcile(t *testing.T) {
fields: fields{
EsClient: newFakeEsClient(t).withCapacity("storage-scaled-horizontally"),
recorder: record.NewFakeRecorder(1000),
licenseChecker: &fakeLicenceChecker{},
licenseChecker: &license.MockLicenseChecker{EnterpriseEnabled: true},
},
args: args{
esManifest: "storage-scaled-horizontally",
Expand All @@ -195,7 +195,7 @@ func TestReconcile(t *testing.T) {
fields: fields{
EsClient: newFakeEsClient(t),
recorder: record.NewFakeRecorder(1000),
licenseChecker: &fakeLicenceChecker{},
licenseChecker: &license.MockLicenseChecker{EnterpriseEnabled: true},
},
args: args{
esManifest: "",
Expand Down Expand Up @@ -374,19 +374,3 @@ func (f *fakeEsClient) GetAutoscalingCapacity(_ context.Context) (esclient.Autos
func (f *fakeEsClient) UpdateMLNodesSettings(_ context.Context, maxLazyMLNodes int32, maxMemory string) error {
return nil
}

// - Fake licence checker

type fakeLicenceChecker struct{}

func (flc *fakeLicenceChecker) CurrentEnterpriseLicense() (*license.EnterpriseLicense, error) {
return nil, nil
}

func (flc *fakeLicenceChecker) EnterpriseFeaturesEnabled() (bool, error) {
return true, nil
}

func (flc *fakeLicenceChecker) Valid(l license.EnterpriseLicense) (bool, error) {
return true, nil
}
38 changes: 29 additions & 9 deletions pkg/controller/common/license/check.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package license

import (
"context"
"fmt"
"sort"
"time"

Expand All @@ -24,6 +25,7 @@ type Checker interface {
CurrentEnterpriseLicense() (*EnterpriseLicense, error)
EnterpriseFeaturesEnabled() (bool, error)
Valid(l EnterpriseLicense) (bool, error)
ValidOperatorLicenseKeyType() (OperatorLicenseType, error)
}

// checker contains parameters for license checks.
Expand Down Expand Up @@ -64,7 +66,7 @@ func (lc *checker) CurrentEnterpriseLicense() (*EnterpriseLicense, error) {
}

sort.Slice(licenses, func(i, j int) bool {
t1, t2 := EnterpriseLicenseTypeOrder[licenses[i].License.Type], EnterpriseLicenseTypeOrder[licenses[j].License.Type]
t1, t2 := OperatorLicenseTypeOrder[licenses[i].License.Type], OperatorLicenseTypeOrder[licenses[j].License.Type]
if t1 != t2 { // sort by type (first the most features)
return t1 > t2
}
Expand Down Expand Up @@ -115,20 +117,38 @@ func (lc *checker) Valid(l EnterpriseLicense) (bool, error) {
return false, nil
}

type MockChecker struct {
MissingLicense bool
// ValidOperatorLicenseKeyType returns true if the current operator license key is valid
func (lc checker) ValidOperatorLicenseKeyType() (OperatorLicenseType, error) {
lic, err := lc.CurrentEnterpriseLicense()
if err != nil {
log.V(-1).Info("Invalid Enterprise license, fallback to Basic: " + err.Error())
}

licType := lic.GetOperatorLicenseType()
if _, valid := OperatorLicenseTypeOrder[licType]; !valid {
return licType, fmt.Errorf("invalid license key: %s", licType)
}
return licType, nil
}

func (m MockChecker) CurrentEnterpriseLicense() (*EnterpriseLicense, error) {
type MockLicenseChecker struct {
EnterpriseEnabled bool
}

func (m MockLicenseChecker) CurrentEnterpriseLicense() (*EnterpriseLicense, error) {
return &EnterpriseLicense{}, nil
}

func (m MockChecker) EnterpriseFeaturesEnabled() (bool, error) {
return !m.MissingLicense, nil
func (m MockLicenseChecker) EnterpriseFeaturesEnabled() (bool, error) {
return m.EnterpriseEnabled, nil
}

func (m MockLicenseChecker) Valid(l EnterpriseLicense) (bool, error) {
return m.EnterpriseEnabled, nil
}

func (m MockChecker) Valid(l EnterpriseLicense) (bool, error) {
return !m.MissingLicense, nil
func (m MockLicenseChecker) ValidOperatorLicenseKeyType() (OperatorLicenseType, error) {
return LicenseTypeEnterprise, nil
}

var _ Checker = &MockChecker{}
var _ Checker = &MockLicenseChecker{}
99 changes: 99 additions & 0 deletions pkg/controller/common/license/check_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,3 +236,102 @@ func Test_CurrentEnterpriseLicense(t *testing.T) {
})
}
}

func Test_ValidOperatorLicenseKey(t *testing.T) {
privKey, err := x509.ParsePKCS1PrivateKey(privateKeyFixture)
require.NoError(t, err)

validLicenseFixture := licenseFixtureV3
validLicenseFixture.License.ExpiryDateInMillis = chrono.ToMillis(time.Now().Add(1 * time.Hour))
signatureBytes, err := NewSigner(privKey).Sign(validLicenseFixture)
require.NoError(t, err)
validLicense := asRuntimeObjects(validLicenseFixture, signatureBytes)

trialState, err := NewTrialState()
require.NoError(t, err)
validTrialLicenseFixture := emptyTrialLicenseFixture
require.NoError(t, trialState.InitTrialLicense(&validTrialLicenseFixture))
validTrialLicense := asRuntimeObject(validTrialLicenseFixture)

statusSecret, err := ExpectedTrialStatus(testNS, types.NamespacedName{}, trialState)
require.NoError(t, err)

type fields struct {
initialObjects []runtime.Object
operatorNamespace string
publicKey []byte
}

tests := []struct {
name string
fields fields
wantErr bool
wantType OperatorLicenseType
}{
{
name: "get valid basic license: OK",
fields: fields{
initialObjects: []runtime.Object{},
operatorNamespace: "test-system",
},
wantType: LicenseTypeBasic,
wantErr: false,
},
{
name: "get valid enterprise license: OK",
fields: fields{
initialObjects: validLicense,
operatorNamespace: "test-system",
publicKey: publicKeyBytesFixture(t),
},
wantType: LicenseTypeEnterprise,
wantErr: false,
},
{
name: "get valid trial enterprise license: OK",
fields: fields{
initialObjects: []runtime.Object{validTrialLicense, &statusSecret},
operatorNamespace: "test-system",
publicKey: publicKeyBytesFixture(t),
},
wantType: LicenseTypeEnterpriseTrial,
wantErr: false,
},
{
name: "get valid enterprise license among two licenses: OK",
fields: fields{
initialObjects: append(validLicense, validTrialLicense),
operatorNamespace: "test-system",
publicKey: publicKeyBytesFixture(t),
},
wantType: LicenseTypeEnterprise,
wantErr: false,
},
{
name: "invalid public key: fallback to basic",
fields: fields{
initialObjects: validLicense,
operatorNamespace: "test-system",
publicKey: []byte("not a public key"),
},
wantType: LicenseTypeBasic,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
lc := &checker{
k8sClient: k8s.NewFakeClient(tt.fields.initialObjects...),
operatorNamespace: tt.fields.operatorNamespace,
publicKey: tt.fields.publicKey,
}
licenseType, err := lc.ValidOperatorLicenseKeyType()
if (err != nil) != tt.wantErr {
t.Errorf("Checker.ValidOperatorLicenseKeyType() err = %v, wantErr %v", err, tt.wantErr)
}
if licenseType != tt.wantType {
t.Errorf("Checker.ValidOperatorLicenseKeyType() licenseType = %v, wantType %v", licenseType, tt.wantType)
}
})
}
}
13 changes: 11 additions & 2 deletions pkg/controller/common/license/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
type OperatorLicenseType string

const (
LicenseTypeBasic OperatorLicenseType = "basic"
LicenseTypeEnterprise OperatorLicenseType = "enterprise"
LicenseTypeEnterpriseTrial OperatorLicenseType = "enterprise_trial"
// LicenseTypeLegacyTrial earlier versions of ECK used this as the trial identifier
Expand Down Expand Up @@ -47,8 +48,9 @@ type LicenseSpec struct {
Version int // not marshalled but part of the signature
}

// EnterpriseLicenseTypeOrder license types mapped to ints in increasing order of feature sets for sorting purposes.
var EnterpriseLicenseTypeOrder = map[OperatorLicenseType]int{
// OperatorLicenseTypeOrder license types mapped to ints in increasing order of feature sets for sorting purposes.
var OperatorLicenseTypeOrder = map[OperatorLicenseType]int{
LicenseTypeBasic: -1,
LicenseTypeLegacyTrial: 0,
LicenseTypeEnterpriseTrial: 1,
LicenseTypeEnterprise: 2,
Expand Down Expand Up @@ -107,6 +109,13 @@ func (l EnterpriseLicense) IsMissingFields() error {
return nil
}

func (l *EnterpriseLicense) GetOperatorLicenseType() OperatorLicenseType {
if l == nil {
return LicenseTypeBasic
}
return l.License.Type
}

// LicenseStatus expresses the validity status of a license.
type LicenseStatus string

Expand Down
Loading

0 comments on commit 138b0e1

Please sign in to comment.