diff --git a/src/k8s-extension/HISTORY.rst b/src/k8s-extension/HISTORY.rst index 42ad54552ac..d738bdcf21c 100644 --- a/src/k8s-extension/HISTORY.rst +++ b/src/k8s-extension/HISTORY.rst @@ -3,6 +3,10 @@ Release History =============== +1.2.2 +++++++++++++++++++ +* microsoft.azureml.kubernetes: disable service bus by default, do not create relay for managed clusters. + 1.2.1 ++++++++++++++++++ * Provide no default values for Patch of Extension diff --git a/src/k8s-extension/azext_k8s_extension/custom.py b/src/k8s-extension/azext_k8s_extension/custom.py index 49e392ddbae..a4aa858ee82 100644 --- a/src/k8s-extension/azext_k8s_extension/custom.py +++ b/src/k8s-extension/azext_k8s_extension/custom.py @@ -297,6 +297,7 @@ def update_k8s_extension( version, config_settings, config_protected_settings, + extension, yes, ) diff --git a/src/k8s-extension/azext_k8s_extension/partner_extensions/AzureMLKubernetes.py b/src/k8s-extension/azext_k8s_extension/partner_extensions/AzureMLKubernetes.py index 88d2595796b..97fab3fd8aa 100644 --- a/src/k8s-extension/azext_k8s_extension/partner_extensions/AzureMLKubernetes.py +++ b/src/k8s-extension/azext_k8s_extension/partner_extensions/AzureMLKubernetes.py @@ -57,6 +57,7 @@ def __init__(self): self.AZURE_LOG_ANALYTICS_CONNECTION_STRING = 'azure_log_analytics.connection_string' self.JOB_SCHEDULER_LOCATION_KEY = 'jobSchedulerLocation' self.CLUSTER_NAME_FRIENDLY_KEY = 'cluster_name_friendly' + self.NGINX_INGRESS_ENABLED_KEY = 'nginxIngress.enabled' # component flag self.ENABLE_TRAINING = 'enableTraining' @@ -66,6 +67,10 @@ def __init__(self): self.RELAY_SERVER_CONNECTION_STRING = 'relayServerConnectionString' # create relay connection string if None self.SERVICE_BUS_CONNECTION_STRING = 'serviceBusConnectionString' # create service bus if None self.LOG_ANALYTICS_WS_ENABLED = 'logAnalyticsWS' # create log analytics workspace if true + # default to false when creating the extension + self.SERVICE_BUS_ENABLED = 'servicebus.enabled' + # default to 
false if cluster is AKS when creating the extension + self.RELAY_SERVER_ENABLED = 'relayserver.enabled' # constants for azure resources creation self.RELAY_HC_AUTH_NAME = 'azureml_rw' @@ -106,9 +111,10 @@ def Create(self, cmd, client, resource_group_name, cluster_name, name, cluster_t configuration_settings_file, configuration_protected_settings_file): if scope == 'namespace': raise InvalidArgumentValueError("Invalid scope '{}'. This extension can't be installed " - "only at 'cluster' scope.".format(scope)) - if not release_namespace: - release_namespace = self.DEFAULT_RELEASE_NAMESPACE + "only at 'cluster' scope. " + "Check https://aka.ms/arcmltsg for more information.".format(scope)) + # set release name explicitly to azureml + release_namespace = self.DEFAULT_RELEASE_NAMESPACE scope_cluster = ScopeCluster(release_namespace=release_namespace) ext_scope = Scope(cluster=scope_cluster, namespace=None) @@ -136,6 +142,8 @@ def Create(self, cmd, client, resource_group_name, cluster_name, name, cluster_t nodeCount += agent['count'] if nodeCount < 3: configuration_settings['clusterPurpose'] = 'DevTest' + if resource.properties.get('distribution', '').lower() == self.OPEN_SHIFT: + configuration_settings[self.OPEN_SHIFT] = 'true' except: pass except CloudError as ex: @@ -150,6 +158,20 @@ def Create(self, cmd, client, resource_group_name, cluster_name, name, cluster_t self.JOB_SCHEDULER_LOCATION_KEY, cluster_location) configuration_settings[self.CLUSTER_NAME_FRIENDLY_KEY] = configuration_settings.get( self.CLUSTER_NAME_FRIENDLY_KEY, cluster_name) + # do not enable service bus by default + configuration_settings[self.SERVICE_BUS_ENABLED] = configuration_settings.get(self.SERVICE_BUS_ENABLED, 'false') + + # do not enable relay for managed cluster(AKS) by default, do not enable nginx for ARC by default + if cluster_type == "managedClusters": + configuration_settings[self.RELAY_SERVER_ENABLED] = configuration_settings.get(self.RELAY_SERVER_ENABLED, + 'false') + 
configuration_settings[self.NGINX_INGRESS_ENABLED_KEY] = configuration_settings.get( + self.NGINX_INGRESS_ENABLED_KEY, 'true') + else: + configuration_settings[self.RELAY_SERVER_ENABLED] = configuration_settings.get(self.RELAY_SERVER_ENABLED, + 'true') + configuration_settings[self.NGINX_INGRESS_ENABLED_KEY] = configuration_settings.get( + self.NGINX_INGRESS_ENABLED_KEY, 'false') # create Azure resources need by the extension based on the config. self.__create_required_resource( @@ -186,7 +208,14 @@ def Delete(self, cmd, client, resource_group_name, cluster_name, name, cluster_t user_confirmation_factory(cmd, yes) def Update(self, cmd, resource_group_name, cluster_name, auto_upgrade_minor_version, release_train, version, configuration_settings, - configuration_protected_settings, yes=False): + configuration_protected_settings, original_extension, yes=False): + input_configuration_settings = copy.deepcopy(configuration_settings) + input_configuration_protected_settings = copy.deepcopy(configuration_protected_settings) + # configuration_settings and configuration_protected_settings can be none, so need to set them to empty dict + if configuration_settings is None: + configuration_settings = {} + if configuration_protected_settings is None: + configuration_protected_settings = {} self.__normalize_config(configuration_settings, configuration_protected_settings) # Prompt message to ask customer to confirm again @@ -281,7 +310,11 @@ def Update(self, cmd, resource_group_name, cluster_name, auto_upgrade_minor_vers except azure.core.exceptions.HttpResponseError: logger.info("Failed to get log analytics connection string.") - if self.RELAY_SERVER_CONNECTION_STRING not in configuration_protected_settings: + original_extension_config_settings = original_extension.configuration_settings + if original_extension_config_settings is None: + original_extension_config_settings = {} + if str(original_extension_config_settings.get(self.RELAY_SERVER_ENABLED)).lower() != 'false' \ + and
self.RELAY_SERVER_CONNECTION_STRING not in configuration_protected_settings: try: relay_connection_string, _, _ = _get_relay_connection_str( cmd, subscription_id, resource_group_name, cluster_name, '', self.RELAY_HC_AUTH_NAME, True) @@ -289,10 +322,13 @@ def Update(self, cmd, resource_group_name, cluster_name, auto_upgrade_minor_vers logger.info("Get relay connection string succeeded.") except azure.mgmt.relay.models.ErrorResponseException as ex: if ex.response.status_code == 404: - raise ResourceNotFoundError("Relay server not found.") from ex - raise AzureResponseError("Failed to get relay connection string.") from ex + raise ResourceNotFoundError("Relay server not found. " + "Check https://aka.ms/arcmltsg for more information.") from ex + raise AzureResponseError("Failed to get relay connection string. " + "Check https://aka.ms/arcmltsg for more information.") from ex - if self.SERVICE_BUS_CONNECTION_STRING not in configuration_protected_settings: + if str(original_extension_config_settings.get(self.SERVICE_BUS_ENABLED)).lower() != 'false' \ + and self.SERVICE_BUS_CONNECTION_STRING not in configuration_protected_settings: try: service_bus_connection_string, _ = _get_service_bus_connection_string( cmd, subscription_id, resource_group_name, cluster_name, '', {}, True) @@ -300,8 +336,10 @@ def Update(self, cmd, resource_group_name, cluster_name, auto_upgrade_minor_vers logger.info("Get service bus connection string succeeded.") except azure.core.exceptions.HttpResponseError as ex: if ex.response.status_code == 404: - raise ResourceNotFoundError("Service bus not found.") from ex - raise AzureResponseError("Failed to get service bus connection string.") from ex + raise ResourceNotFoundError("Service bus not found. " + "Check https://aka.ms/arcmltsg for more information.") from ex + raise AzureResponseError("Failed to get service bus connection string. "
+ "Check https://aka.ms/arcmltsg for more information.") from ex configuration_protected_settings = _dereference(self.reference_mapping, configuration_protected_settings) @@ -314,6 +352,13 @@ def Update(self, cmd, resource_group_name, cluster_name, auto_upgrade_minor_vers if fe_ssl_cert_file and fe_ssl_key_file: self.__set_inference_ssl_from_file(configuration_protected_settings, fe_ssl_cert_file, fe_ssl_key_file) + # if no entries are existed in configuration_protected_settings, configuration_settings, return whatever passed + # in the Update function(empty dict or None). + if len(configuration_settings) == 0: + configuration_settings = input_configuration_settings + if len(configuration_protected_settings) == 0: + configuration_protected_settings = input_configuration_protected_settings + return PatchExtension(auto_upgrade_minor_version=auto_upgrade_minor_version, release_train=release_train, version=version, @@ -324,11 +369,12 @@ def __normalize_config(self, configuration_settings, configuration_protected_set # inference inferenceRouterHA = _get_value_from_config_protected_config( self.inferenceRouterHA, configuration_settings, configuration_protected_settings) - isTestCluster = True if inferenceRouterHA is not None and str(inferenceRouterHA).lower() == 'false' else False - if isTestCluster: - configuration_settings['clusterPurpose'] = 'DevTest' - else: - configuration_settings['clusterPurpose'] = 'FastProd' + if inferenceRouterHA is not None: + isTestCluster = str(inferenceRouterHA).lower() == 'false' + if isTestCluster: + configuration_settings['clusterPurpose'] = 'DevTest' + else: + configuration_settings['clusterPurpose'] = 'FastProd' inferenceRouterServiceType = _get_value_from_config_protected_config( self.inferenceRouterServiceType, configuration_settings, configuration_protected_settings) @@ -358,7 +404,8 @@ def __validate_config(self, configuration_settings, configuration_protected_sett for key in dup_keys: logger.warning( 'Duplicate keys found in both 
configuration settings and configuration protected setttings: %s', key) - raise InvalidArgumentValueError("Duplicate keys found.") + raise InvalidArgumentValueError("Duplicate keys found. " + "Check https://aka.ms/arcmltsg for more information.") enable_training = _get_value_from_config_protected_config( self.ENABLE_TRAINING, configuration_settings, configuration_protected_settings) @@ -436,7 +483,8 @@ def __validate_scoring_fe_settings(self, configuration_settings, configuration_p if feIsNodePort and feIsInternalLoadBalancer: raise MutuallyExclusiveArgumentError( - "When using nodePort as inferenceRouterServiceType, no need to specify internalLoadBalancerProvider.") + "When using nodePort as inferenceRouterServiceType, no need to specify internalLoadBalancerProvider. " + "Check https://aka.ms/arcmltsg for more information.") if feIsNodePort: configuration_settings['scoringFe.serviceType.nodePort'] = feIsNodePort elif feIsInternalLoadBalancer: @@ -493,7 +541,8 @@ def __create_required_resource( configuration_settings[self.AZURE_LOG_ANALYTICS_CUSTOMER_ID_KEY] = ws_costumer_id configuration_protected_settings[self.AZURE_LOG_ANALYTICS_CONNECTION_STRING] = shared_key - if not configuration_settings.get(self.RELAY_SERVER_CONNECTION_STRING) and \ + if str(configuration_settings.get(self.RELAY_SERVER_ENABLED)).lower() != 'false' and \ + not configuration_settings.get(self.RELAY_SERVER_CONNECTION_STRING) and \ not configuration_protected_settings.get(self.RELAY_SERVER_CONNECTION_STRING): logger.info('==== BEGIN RELAY CREATION ====') relay_connection_string, hc_resource_id, hc_name = _get_relay_connection_str( @@ -503,7 +552,8 @@ def __create_required_resource( configuration_settings[self.HC_RESOURCE_ID_KEY] = hc_resource_id configuration_settings[self.RELAY_HC_NAME_KEY] = hc_name - if not configuration_settings.get(self.SERVICE_BUS_CONNECTION_STRING) and \ + if str(configuration_settings.get(self.SERVICE_BUS_ENABLED)).lower() != 'false' and \ + not
configuration_settings.get(self.SERVICE_BUS_CONNECTION_STRING) and \ not configuration_protected_settings.get(self.SERVICE_BUS_CONNECTION_STRING): logger.info('==== BEGIN SERVICE BUS CREATION ====') topic_sub_mapping = { @@ -675,9 +725,11 @@ def _dereference(ref_mapping_dict: Dict[str, List], output_dict: Dict[str, Any]) def _get_value_from_config_protected_config(key, config, protected_config): - if key in config: + if config is not None and key in config: return config[key] - return protected_config.get(key) + if protected_config is not None: + return protected_config.get(key) + return None def _check_nodeselector_existed(configuration_settings, configuration_protected_settings): diff --git a/src/k8s-extension/azext_k8s_extension/partner_extensions/DefaultExtension.py b/src/k8s-extension/azext_k8s_extension/partner_extensions/DefaultExtension.py index a15defc72d2..bae3fe4d1f0 100644 --- a/src/k8s-extension/azext_k8s_extension/partner_extensions/DefaultExtension.py +++ b/src/k8s-extension/azext_k8s_extension/partner_extensions/DefaultExtension.py @@ -71,6 +71,7 @@ def Update( version, configuration_settings, configuration_protected_settings, + original_extension: Extension, yes=False, ): """Default validations & defaults for Update diff --git a/src/k8s-extension/azext_k8s_extension/partner_extensions/PartnerExtensionModel.py b/src/k8s-extension/azext_k8s_extension/partner_extensions/PartnerExtensionModel.py index d4f1eeba6c3..c83429e2cfb 100644 --- a/src/k8s-extension/azext_k8s_extension/partner_extensions/PartnerExtensionModel.py +++ b/src/k8s-extension/azext_k8s_extension/partner_extensions/PartnerExtensionModel.py @@ -43,6 +43,7 @@ def Update( version: str, configuration_settings: dict, configuration_protected_settings: dict, + original_extension: Extension, yes: bool, ) -> PatchExtension: pass diff --git a/testing/test/extensions/public/AzureMLKubernetes.Tests.ps1 b/testing/test/extensions/public/AzureMLKubernetes.Tests.ps1 index 1b3b6bb662d..404e50dcd78 
100644 --- a/testing/test/extensions/public/AzureMLKubernetes.Tests.ps1 +++ b/testing/test/extensions/public/AzureMLKubernetes.Tests.ps1 @@ -5,6 +5,8 @@ Describe 'AzureML Kubernetes Testing' { $extensionAgentNamespace = "azureml" $relayResourceIDKey = "relayserver.hybridConnectionResourceID" $serviceBusResourceIDKey = "servicebus.resourceID" + $mockUpdateKey = "mockTest" + $mockProtectedUpdateKey = "mockProtectedTest" . $PSScriptRoot/../../helper/Constants.ps1 . $PSScriptRoot/../../helper/Helper.ps1 @@ -13,7 +15,7 @@ Describe 'AzureML Kubernetes Testing' { It 'Creates the extension and checks that it onboards correctly with inference and SSL enabled' { $sslKeyPemFile = Join-Path (Join-Path (Join-Path (Split-Path $PSScriptRoot -Parent) "data") "azure_ml") "test_key.pem" $sslCertPemFile = Join-Path (Join-Path (Join-Path (Split-Path $PSScriptRoot -Parent) "data") "azure_ml") "test_cert.pem" - az $Env:K8sExtensionName create -c $($ENVCONFIG.arcClusterName) -g $($ENVCONFIG.resourceGroup) --cluster-type connectedClusters --extension-type $extensionType -n $extensionName --release-train staging --config enableInference=true identity.proxy.remoteEnabled=True identity.proxy.remoteHost=https://master.experiments.azureml-test.net inferenceRouterServiceType=nodePort sslCname=test.domain --config-protected sslKeyPemFile=$sslKeyPemFile sslCertPemFile=$sslCertPemFile --no-wait + az $Env:K8sExtensionName create -c $($ENVCONFIG.arcClusterName) -g $($ENVCONFIG.resourceGroup) --cluster-type connectedClusters --extension-type $extensionType -n $extensionName --release-train stable --config enableInference=true identity.proxy.remoteEnabled=True identity.proxy.remoteHost=https://master.experiments.azureml-test.net inferenceRouterServiceType=nodePort sslCname=test.domain --config-protected sslKeyPemFile=$sslKeyPemFile sslCertPemFile=$sslCertPemFile --no-wait $? 
| Should -BeTrue $output = az $Env:K8sExtensionName show -c $($ENVCONFIG.arcClusterName) -g $($ENVCONFIG.resourceGroup) --cluster-type connectedClusters -n $extensionName @@ -55,14 +57,57 @@ Describe 'AzureML Kubernetes Testing' { $extensionExists | Should -Not -BeNullOrEmpty } + It "Wait for the extension to be ready" { + # Loop and retry until the extension installed + $n = 0 + do + { + + $output = az $Env:K8sExtensionName show -c $($ENVCONFIG.arcClusterName) -g $($ENVCONFIG.resourceGroup) --cluster-type connectedClusters -n $extensionName + $? | Should -BeTrue + + $provisioningState = ($output | ConvertFrom-Json).provisioningState + Write-Host "Provisioning state: $provisioningState" + if ($provisioningState -eq "Succeeded") { + break + } + Start-Sleep -Seconds 20 + $n += 1 + } while ($n -le $MAX_RETRY_ATTEMPTS) + $n | Should -BeLessOrEqual $MAX_RETRY_ATTEMPTS + } + + It "Perform Update extension" { + az $Env:K8sExtensionName update -c $($ENVCONFIG.arcClusterName) -g $($ENVCONFIG.resourceGroup) --cluster-type connectedClusters -n $extensionName --config "$($mockUpdateKey)=true" --config-protected "$($mockProtectedUpdateKey)=true" --no-wait + $? | Should -BeTrue + + # Loop and retry until the extension updated + $n = 0 + do + { + + $output = az $Env:K8sExtensionName show -c $($ENVCONFIG.arcClusterName) -g $($ENVCONFIG.resourceGroup) --cluster-type connectedClusters -n $extensionName + $? 
| Should -BeTrue + + $provisioningState = ($output | ConvertFrom-Json).provisioningState + Write-Host "Provisioning state: $provisioningState" + if ($provisioningState -eq "Succeeded") { + break + } + Start-Sleep -Seconds 20 + $n += 1 + } while ($n -le $MAX_RETRY_ATTEMPTS) + $n | Should -BeLessOrEqual $MAX_RETRY_ATTEMPTS + + $mockedUpdateData = Get-ExtensionConfigurationSettings $extensionName $mockUpdateKey + $mockedUpdateData | Should -Not -BeNullOrEmpty + } + It "Deletes the extension from the cluster with inference enabled" { # cleanup the relay and servicebus $relayResourceID = Get-ExtensionConfigurationSettings $extensionName $relayResourceIDKey - $serviceBusResourceID = Get-ExtensionConfigurationSettings $extensionName $serviceBusResourceIDKey $relayNamespaceName = $relayResourceID.split("/")[8] - $serviceBusNamespaceName = $serviceBusResourceID.split("/")[8] az relay namespace delete --resource-group $ENVCONFIG.resourceGroup --name $relayNamespaceName - az servicebus namespace delete --resource-group $ENVCONFIG.resourceGroup --name $serviceBusNamespaceName $output = az $Env:K8sExtensionName delete -c $($ENVCONFIG.arcClusterName) -g $($ENVCONFIG.resourceGroup) --cluster-type connectedClusters -n $extensionName --force $? | Should -BeTrue