diff --git a/api/v1/azureappconfigurationprovider_types.go b/api/v1/azureappconfigurationprovider_types.go index e5e1f42..c2b6fae 100644 --- a/api/v1/azureappconfigurationprovider_types.go +++ b/api/v1/azureappconfigurationprovider_types.go @@ -134,6 +134,7 @@ type AzureAppConfigurationProviderAuth struct { type WorkloadIdentityParameters struct { ManagedIdentityClientId *string `json:"managedIdentityClientId,omitempty"` ManagedIdentityClientIdReference *ManagedIdentityReferenceParameters `json:"managedIdentityClientIdReference,omitempty"` + ServiceAccountName *string `json:"serviceAccountName,omitempty"` } // ManagedIdentityReferenceParameters defines the parameters for configmap reference diff --git a/api/v1/zz_generated.deepcopy.go b/api/v1/zz_generated.deepcopy.go index 9ef0ba7..6e8d3ff 100644 --- a/api/v1/zz_generated.deepcopy.go +++ b/api/v1/zz_generated.deepcopy.go @@ -506,6 +506,11 @@ func (in *WorkloadIdentityParameters) DeepCopyInto(out *WorkloadIdentityParamete *out = new(ManagedIdentityReferenceParameters) **out = **in } + if in.ServiceAccountName != nil { + in, out := &in.ServiceAccountName, &out.ServiceAccountName + *out = new(string) + **out = **in + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new WorkloadIdentityParameters. diff --git a/config/crd/bases/azconfig.io_azureappconfigurationproviders.yaml b/config/crd/bases/azconfig.io_azureappconfigurationproviders.yaml index 3611000..6d67916 100644 --- a/config/crd/bases/azconfig.io_azureappconfigurationproviders.yaml +++ b/config/crd/bases/azconfig.io_azureappconfigurationproviders.yaml @@ -67,6 +67,8 @@ spec: - configMap - key type: object + serviceAccountName: + type: string type: object type: object configuration: @@ -208,6 +210,8 @@ spec: - configMap - key type: object + serviceAccountName: + type: string type: object required: - uri @@ -239,6 +243,8 @@ spec: - configMap - key type: object + serviceAccountName: + type: string type: object type: object refresh: diff --git a/deploy/parameter/helm-values.yaml b/deploy/parameter/helm-values.yaml index 30bdb21..1bc7291 100644 --- a/deploy/parameter/helm-values.yaml +++ b/deploy/parameter/helm-values.yaml @@ -17,6 +17,7 @@ fullnameOverride: "az-appconfig-k8s-provider" workloadIdentity: enabled: true + disableGlobalServiceAccount: false serviceAccount: # Specifies whether a service account should be created diff --git a/deploy/templates/_helpers.tpl b/deploy/templates/_helpers.tpl index 0f40c4e..088d547 100644 --- a/deploy/templates/_helpers.tpl +++ b/deploy/templates/_helpers.tpl @@ -49,7 +49,7 @@ Selector labels app.kubernetes.io/name: {{ include "az-appconfig-k8s-provider.name" . }} app.kubernetes.io/instance: {{ .Release.Name }} control-plane: controller-manager -{{- if eq .Values.workloadIdentity.enabled true }} +{{- if and (.Values.workloadIdentity.enabled) (not .Values.workloadIdentity.disableGlobalServiceAccount) }} azure.workload.identity/use: "true" {{- end }} {{- end }} diff --git a/deploy/templates/clusterrole.yaml b/deploy/templates/clusterrole.yaml index 8b9697c..64bffcb 100644 --- a/deploy/templates/clusterrole.yaml +++ b/deploy/templates/clusterrole.yaml @@ -53,6 +53,20 @@ rules: - patch - update - watch +{{- if .Values.workloadIdentity.enabled }} +- apiGroups: + - "" + resources: + - serviceaccounts + verbs: + - get +- apiGroups: + - "" + resources: + - serviceaccounts/token + verbs: + - create +{{- end }} --- apiVersion: rbac.authorization.k8s.io/v1 kind: ClusterRoleBinding diff --git a/deploy/templates/deployment.yaml b/deploy/templates/deployment.yaml index a5754a6..c2ba28a 100644 --- a/deploy/templates/deployment.yaml +++ b/deploy/templates/deployment.yaml @@ -51,6 +51,10 @@ spec: - name: AZURE_TENANT_ID value: {{ .Values.env.azureTenantId }} {{- end }} + - name: WORKLOAD_IDENTITY_ENABLED + value: "{{ .Values.workloadIdentity.enabled }}" + - name: WORKLOAD_IDENTITY_DISABLE_GLOBAL_SERVICE_ACCOUNT + value: "{{ .Values.workloadIdentity.disableGlobalServiceAccount }}" livenessProbe: httpGet: path: /healthz diff --git a/deploy/templates/serviceaccount.yaml b/deploy/templates/serviceaccount.yaml index cf835e9..c036964 100644 --- a/deploy/templates/serviceaccount.yaml +++ b/deploy/templates/serviceaccount.yaml @@ -9,7 +9,7 @@ metadata: {{- if .Values.serviceAccount.annotations }} {{ toYaml .Values.serviceAccount.annotations . | nindent 4 }} {{- end }} - {{- if eq .Values.workloadIdentity.enabled true }} + {{- if and (.Values.workloadIdentity.enabled) (not .Values.workloadIdentity.disableGlobalServiceAccount) }} azure.workload.identity/client-id: "" {{- end }} {{- end }} \ No newline at end of file diff --git a/internal/controller/appconfigurationprovider_controller_test.go b/internal/controller/appconfigurationprovider_controller_test.go index 48262fd..e49d61a 100644 --- a/internal/controller/appconfigurationprovider_controller_test.go +++ b/internal/controller/appconfigurationprovider_controller_test.go @@ -7,6 +7,7 @@ import ( "azappconfig/provider/internal/loader" "context" "fmt" + "os" "time" acpv1 "azappconfig/provider/api/v1" @@ -1059,9 +1060,12 @@ var _ = Describe("AppConfiguationProvider controller", func() { Context("Verify auth object", func() { It("Should return no error if auth object is valid", func() { + os.Setenv("WORKLOAD_IDENTITY_ENABLED", "true") + uuid1 := "86c613ca-b977-11ed-afa1-0242ac120002" secretName := "fakeName1" configMapName := "fakeName2" + serviceAccountName := "fakeName3" key := "fakeKey" authObj := &acpv1.AzureAppConfigurationProviderAuth{} authObj2 := &acpv1.AzureAppConfigurationProviderAuth{ @@ -1070,7 +1074,7 @@ var _ = Describe("AppConfiguationProvider controller", func() { authObj3 := &acpv1.AzureAppConfigurationProviderAuth{ ServicePrincipalReference: &secretName, } - autoObj4 := &acpv1.AzureAppConfigurationProviderAuth{ + authObj4 := &acpv1.AzureAppConfigurationProviderAuth{ WorkloadIdentity: &acpv1.WorkloadIdentityParameters{ ManagedIdentityClientId: &uuid1, }, @@ -1083,12 +1087,18 @@ var _ = Describe("AppConfiguationProvider controller", func() { }, }, } + authObj6 := &acpv1.AzureAppConfigurationProviderAuth{ + WorkloadIdentity: &acpv1.WorkloadIdentityParameters{ + ServiceAccountName: &serviceAccountName, + }, + } Expect(verifyAuthObject(nil)).Should(BeNil()) Expect(verifyAuthObject(authObj)).Should(BeNil()) Expect(verifyAuthObject(authObj2)).Should(BeNil()) Expect(verifyAuthObject(authObj3)).Should(BeNil()) - Expect(verifyAuthObject(autoObj4)).Should(BeNil()) + Expect(verifyAuthObject(authObj4)).Should(BeNil()) Expect(verifyAuthObject(authObj5)).Should(BeNil()) + Expect(verifyAuthObject(authObj6)).Should(BeNil()) }) It("Should return error if auth object is not valid", func() { @@ -1123,9 +1133,9 @@ var _ = Describe("AppConfiguationProvider controller", func() { } Expect(verifyAuthObject(authObj).Error()).Should(Equal("auth: ManagedIdentityClientId \"not-a-uuid\" in auth field is not a valid uuid")) Expect(verifyAuthObject(authObj2).Error()).Should(Equal("auth: more than one authentication methods are specified in 'auth' field")) - Expect(verifyAuthObject(authObj3).Error()).Should(Equal("auth.workloadIdentity: only one of managedIdentityClientId and managedIdentityClientIdReference is allowed")) + Expect(verifyAuthObject(authObj3).Error()).Should(Equal("auth.workloadIdentity: setting only one of 'managedIdentityClientId', 'managedIdentityClientIdReference' or 'serviceAccountName' field is allowed")) Expect(verifyAuthObject(authObj4).Error()).Should(Equal("auth.workloadIdentity.managedIdentityClientId: managedIdentityClientId \"not-a-uuid\" in auth.workloadIdentity is not a valid uuid")) - Expect(verifyAuthObject(authObj5).Error()).Should(Equal("auth.workloadIdentity: one of managedIdentityClientId and managedIdentityClientIdReference is required")) + Expect(verifyAuthObject(authObj5).Error()).Should(Equal("auth.workloadIdentity: setting one of 'managedIdentityClientId', 'managedIdentityClientIdReference' or 'serviceAccountName' field is required")) }) }) @@ -1214,8 +1224,8 @@ var _ = Describe("AppConfiguationProvider controller", func() { }, }, } - Expect(verifyExistingTargetObject(configMap1, configProvider.Spec.Target.ConfigMapName, configProvider.Name)).Should(MatchError("A ConfigMap with name 'configMapName' already exists in namespace 'default'")) - Expect(verifyExistingTargetObject(configMap2, configProvider.Spec.Target.ConfigMapName, configProvider.Name)).Should(MatchError("A ConfigMap with name 'configMapName' already exists in namespace 'default'")) + Expect(verifyExistingTargetObject(configMap1, configProvider.Spec.Target.ConfigMapName, configProvider.Name)).Should(MatchError("a ConfigMap with name 'configMapName' already exists in namespace 'default'")) + Expect(verifyExistingTargetObject(configMap2, configProvider.Spec.Target.ConfigMapName, configProvider.Name)).Should(MatchError("a ConfigMap with name 'configMapName' already exists in namespace 'default'")) }) }) }) diff --git a/internal/controller/utils.go b/internal/controller/utils.go index bb7c713..104f774 100644 --- a/internal/controller/utils.go +++ b/internal/controller/utils.go @@ -8,6 +8,7 @@ import ( "azappconfig/provider/internal/loader" "fmt" "net/url" + "os" "strings" "time" @@ -16,9 +17,11 @@ import ( ) const ( - MinimalSentinelBasedRefreshInterval time.Duration = time.Second - MinimalSecretRefreshInterval time.Duration = time.Minute - MinimalFeatureFlagRefreshInterval time.Duration = time.Second + MinimalSentinelBasedRefreshInterval time.Duration = time.Second + MinimalSecretRefreshInterval time.Duration = time.Minute + MinimalFeatureFlagRefreshInterval time.Duration = time.Second + WorkloadIdentityEnabled string = "WORKLOAD_IDENTITY_ENABLED" + WorkloadIdentityDisableGlobalServiceAccount string = "WORKLOAD_IDENTITY_DISABLE_GLOBAL_SERVICE_ACCOUNT" ) func verifyObject(spec acpv1.AzureAppConfigurationProviderSpec) error { @@ -204,7 +207,7 @@ func verifyExistingTargetObject[T client.Object](targetObj T, targetName string, } } - return fmt.Errorf("A %s with name '%s' already exists in namespace '%s'", objectKind, targetName, targetObj.GetNamespace()) + return fmt.Errorf("a %s with name '%s' already exists in namespace '%s'", objectKind, targetName, targetObj.GetNamespace()) } func hasNonEscapedValueInLabel(label string) bool { @@ -237,12 +240,36 @@ func verifyRefreshInterval(interval string, allowedMinimalRefreshInterval time.D } func verifyWorkloadIdentityParameters(workloadIdentity *acpv1.WorkloadIdentityParameters) error { - if workloadIdentity.ManagedIdentityClientId == nil && workloadIdentity.ManagedIdentityClientIdReference == nil { - return loader.NewArgumentError("auth.workloadIdentity", fmt.Errorf("one of managedIdentityClientId and managedIdentityClientIdReference is required")) + if !strings.EqualFold(os.Getenv(WorkloadIdentityEnabled), "true") { + return loader.NewArgumentError("auth.workloadIdentity", fmt.Errorf("workloadIdentity is not enabled")) } - if workloadIdentity.ManagedIdentityClientId != nil && workloadIdentity.ManagedIdentityClientIdReference != nil { - return loader.NewArgumentError("auth.workloadIdentity", fmt.Errorf("only one of managedIdentityClientId and managedIdentityClientIdReference is allowed")) + var authCount int = 0 + + if workloadIdentity.ManagedIdentityClientId != nil { + if strings.EqualFold(os.Getenv(WorkloadIdentityDisableGlobalServiceAccount), "true") { + return loader.NewArgumentError("auth.workloadIdentity.managedIdentityClientId", fmt.Errorf("'managedIdentityClientId' is not allowed since global service account is disabled")) + } + authCount++ + } + + if workloadIdentity.ManagedIdentityClientIdReference != nil { + if strings.EqualFold(os.Getenv(WorkloadIdentityDisableGlobalServiceAccount), "true") { + return loader.NewArgumentError("auth.workloadIdentity.managedIdentityClientIdReference", fmt.Errorf("'managedIdentityClientIdReference' is not allowed since global service account is disabled")) + } + authCount++ + } + + if workloadIdentity.ServiceAccountName != nil { + authCount++ + } + + if authCount == 0 { + return loader.NewArgumentError("auth.workloadIdentity", fmt.Errorf("setting one of 'managedIdentityClientId', 'managedIdentityClientIdReference' or 'serviceAccountName' field is required")) + } + + if authCount > 1 { + return loader.NewArgumentError("auth.workloadIdentity", fmt.Errorf("setting only one of 'managedIdentityClientId', 'managedIdentityClientIdReference' or 'serviceAccountName' field is allowed")) } if workloadIdentity.ManagedIdentityClientId != nil { diff --git a/internal/loader/configuration_client_manager.go b/internal/loader/configuration_client_manager.go index 9ac741d..dadc7a1 100644 --- a/internal/loader/configuration_client_manager.go +++ b/internal/loader/configuration_client_manager.go @@ -8,24 +8,29 @@ import ( "context" "fmt" "math" + "math/rand" "net" "net/url" + "os" "strconv" "strings" "time" acpv1 "azappconfig/provider/api/v1" - "math/rand" - "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/azidentity" "github.com/Azure/azure-sdk-for-go/sdk/data/azappconfig" "github.com/google/uuid" + authv1 "k8s.io/api/authentication/v1" + corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" + "k8s.io/client-go/kubernetes" "k8s.io/klog/v2" + "sigs.k8s.io/controller-runtime/pkg/client" + ctrlcfg "sigs.k8s.io/controller-runtime/pkg/client/config" ) //go:generate mockgen -destination=mocks/mock_configuration_client_manager.go -package mocks . ClientManager @@ -70,6 +75,9 @@ const ( MinBackoffDuration time.Duration = time.Second * 30 JitterRatio float64 = 0.25 SafeShiftLimit int = 63 + ApiTokenExchangeAudience string = "api://AzureADTokenExchange" + AnnotationClientID string = "azure.workload.identity/client-id" + AnnotationTenantID string = "azure.workload.identity/tenant-id" ) var ( @@ -191,10 +199,10 @@ func (manager *ConfigurationClientManager) DiscoverFallbackClients(ctx context.C select { case <-newCtx.Done(): - klog.Warningf("fail to build fall back clients, SRV DNS lookup is timeout") + klog.Warningf("fail to build fallback clients, SRV DNS lookup is timeout") break case err := <-errChan: - klog.Warningf("fail to build fall back clients %s", err.Error()) + klog.Warningf("fail to build fallback clients %s", err.Error()) break case srvTargetHosts := <-resultChan: // Shuffle the list of SRV target hosts @@ -362,10 +370,15 @@ func CreateTokenCredential(ctx context.Context, acpAuth *acpv1.AzureAppConfigura // If User explicitly specify the authentication method if acpAuth != nil { if acpAuth.WorkloadIdentity != nil { + if acpAuth.WorkloadIdentity.ServiceAccountName != nil { + return newClientAssertionCredential(ctx, *acpAuth.WorkloadIdentity.ServiceAccountName, namespace) + } + workloadIdentityClientId, err := getWorkloadIdentityClientId(ctx, acpAuth.WorkloadIdentity, namespace) if err != nil { return nil, fmt.Errorf("fail to retrieve workload identity client ID from configMap '%s' : %s", acpAuth.WorkloadIdentity.ManagedIdentityClientIdReference.ConfigMap, err.Error()) } + return azidentity.NewWorkloadIdentityCredential(&azidentity.WorkloadIdentityCredentialOptions{ ClientID: workloadIdentityClientId, }) @@ -375,6 +388,7 @@ func CreateTokenCredential(ctx context.Context, acpAuth *acpv1.AzureAppConfigura if err != nil { return nil, fmt.Errorf("fail to retrieve service principal secret from '%s': %s", *acpAuth.ServicePrincipalReference, err.Error()) } + return azidentity.NewClientSecretCredential(parameter.TenantId, parameter.ClientId, parameter.ClientSecret, nil) } if acpAuth.ManagedIdentityClientId != nil { @@ -461,3 +475,71 @@ func Jitter(duration time.Duration) time.Duration { // Apply the random jitter to the original duration return duration + time.Duration(randomJitter) } + +func newClientAssertionCredential(ctx context.Context, serviceAccountName string, serviceAccountNamespace string) (azcore.TokenCredential, error) { + cfg, err := ctrlcfg.GetConfig() + if err != nil { + return nil, err + } + + client, err := client.New(cfg, client.Options{}) + if err != nil { + return nil, err + } + + serviceAccountObj := &corev1.ServiceAccount{} + err = client.Get(ctx, types.NamespacedName{Namespace: serviceAccountNamespace, Name: serviceAccountName}, serviceAccountObj) + if err != nil { + return nil, err + } + + if _, ok := serviceAccountObj.Annotations[AnnotationClientID]; !ok { + return nil, fmt.Errorf("annotation '%s' of service account %s/%s is required", AnnotationClientID, serviceAccountNamespace, serviceAccountName) + } + + tenantId := "" + + if _, ok := serviceAccountObj.Annotations[AnnotationTenantID]; ok { + tenantId = serviceAccountObj.Annotations[AnnotationTenantID] + } else if _, ok := os.LookupEnv(strings.ToUpper(AzureTenantId)); ok { + tenantId = os.Getenv(strings.ToUpper(AzureTenantId)) + } else { + return nil, fmt.Errorf("annotation '%s' of service account %s/%s is required since using global service account for workload identity is disabled", AnnotationTenantID, serviceAccountNamespace, serviceAccountName) + } + + getAssertionFunc := newGetAssertionFunc(serviceAccountNamespace, serviceAccountName) + + clientAssertionCredential, err := azidentity.NewClientAssertionCredential(tenantId, serviceAccountObj.Annotations[AnnotationClientID], getAssertionFunc, nil) + if err != nil { + return nil, err + } + + return clientAssertionCredential, nil +} + +func newGetAssertionFunc(serviceAccountNamespace string, serviceAccountName string) func(ctx context.Context) (string, error) { + audiences := []string{ApiTokenExchangeAudience} + + return func(ctx context.Context) (string, error) { + cfg, err := ctrlcfg.GetConfig() + if err != nil { + return "", err + } + + kubeClient, err := kubernetes.NewForConfig(cfg) + if err != nil { + return "", err + } + + token, err := kubeClient.CoreV1().ServiceAccounts(serviceAccountNamespace).CreateToken(ctx, serviceAccountName, &authv1.TokenRequest{ + Spec: authv1.TokenRequestSpec{ + Audiences: audiences, + }, + }, metav1.CreateOptions{}) + if err != nil { + return "", err + } + + return token.Status.Token, nil + } +}