diff --git a/python/az/aro/azext_aro/_params.py b/python/az/aro/azext_aro/_params.py index 30652368e04..1f971b862e1 100644 --- a/python/az/aro/azext_aro/_params.py +++ b/python/az/aro/azext_aro/_params.py @@ -60,7 +60,7 @@ def load_arguments(self, _): c.argument('client_id', help='Client ID of cluster service principal.', - validator=validate_client_id) + validator=validate_client_id(isCreate=True)) c.argument('client_secret', help='Client secret of cluster service principal.', validator=validate_client_secret(isCreate=True)) @@ -148,6 +148,9 @@ def load_arguments(self, _): validator=validate_cluster_identity) with self.argument_context('aro update') as c: + c.argument('client_id', + help='Client ID of cluster service principal.', + validator=validate_client_id(isCreate=False)) c.argument('client_secret', help='Client secret of cluster service principal.', validator=validate_client_secret(isCreate=False)) diff --git a/python/az/aro/azext_aro/_validators.py b/python/az/aro/azext_aro/_validators.py index 983e24cabdc..13091fb0ddd 100644 --- a/python/az/aro/azext_aro/_validators.py +++ b/python/az/aro/azext_aro/_validators.py @@ -38,22 +38,24 @@ def _validate_cidr(namespace): return _validate_cidr -def validate_client_id(namespace): - if namespace.client_id is None: - return - if namespace.enable_managed_identity is True: - raise MutuallyExclusiveArgumentError('Must not specify --client-id when --enable-managed-identity is True') # pylint: disable=line-too-long - if namespace.platform_workload_identities is not None: - raise MutuallyExclusiveArgumentError('Must not specify --client-id when --assign-platform-workload-identity is used') # pylint: disable=line-too-long - try: - uuid.UUID(namespace.client_id) - except ValueError as e: - raise InvalidArgumentValueError(f"Invalid --client-id '{namespace.client_id}'.") from e # pylint: disable=line-too-long +def validate_client_id(isCreate): + def _validate_client_id(namespace): + if namespace.client_id is None: + return + if namespace.enable_managed_identity is True: + raise MutuallyExclusiveArgumentError('Must not specify --client-id when --enable-managed-identity is True') # pylint: disable=line-too-long + if namespace.platform_workload_identities is not None: + raise MutuallyExclusiveArgumentError('Must not specify --client-id when --assign-platform-workload-identity is used') # pylint: disable=line-too-long + try: + uuid.UUID(namespace.client_id) + except ValueError as e: + raise InvalidArgumentValueError(f"Invalid --client-id '{namespace.client_id}'.") from e # pylint: disable=line-too-long - if namespace.client_secret is None or not str(namespace.client_secret): - raise RequiredArgumentMissingError('Must specify --client-secret with --client-id.') # pylint: disable=line-too-long - if namespace.upgradeable_to is not None: - raise MutuallyExclusiveArgumentError('Must not specify --client-id when --upgradeable-to is used.') # pylint: disable=line-too-long + if namespace.client_secret is None or not str(namespace.client_secret): + raise RequiredArgumentMissingError('Must specify --client-secret with --client-id.') # pylint: disable=line-too-long + if not isCreate and namespace.upgradeable_to is not None: + raise MutuallyExclusiveArgumentError('Must not specify --client-id when --upgradeable-to is used.') # pylint: disable=line-too-long + return _validate_client_id def validate_client_secret(isCreate): @@ -66,7 +68,7 @@ def _validate_client_secret(namespace): raise MutuallyExclusiveArgumentError('Must not specify --client-secret when --assign-platform-workload-identity is used') # pylint: disable=line-too-long if isCreate and (namespace.client_id is None or not str(namespace.client_id)): raise RequiredArgumentMissingError('Must specify --client-id with --client-secret.') - if namespace.upgradeable_to is not None: + if not isCreate and namespace.upgradeable_to is not None: raise MutuallyExclusiveArgumentError('Must not specify --client-secret when --upgradeable-to is used.') # pylint: disable=line-too-long return _validate_client_secret diff --git a/python/az/aro/azext_aro/tests/latest/unit/test_validators.py b/python/az/aro/azext_aro/tests/latest/unit/test_validators.py index b252fae48d3..a0f3d16a53d 100644 --- a/python/az/aro/azext_aro/tests/latest/unit/test_validators.py +++ b/python/az/aro/azext_aro/tests/latest/unit/test_validators.py @@ -83,36 +83,43 @@ def test_validate_cidr(test_description, dummyclass, attribute_to_get_from_objec test_validate_client_id_data = [ ( "should not raise any Exception when namespace.client_id is None", + True, Mock(client_id=None), None ), ( "should raise MutuallyExclusiveArgumentError when enable_managed_identity is true", + True, Mock(client_id="12345678123456781234567812345678", enable_managed_identity=True), MutuallyExclusiveArgumentError ), ( "should raise MutuallyExclusiveArgumentError when platform_workload_identities is present", + True, Mock(client_id="12345678123456781234567812345678", platform_workload_identities=[("foo", Mock(resource_id='Foo'))]), MutuallyExclusiveArgumentError ), ( "should raise InvalidArgumentValueError when it can not create a UUID from namespace.client_id", + True, Mock(client_id="invalid_client_id", platform_workload_identities=None), InvalidArgumentValueError ), ( "should raise RequiredArgumentMissingError when can not create a string representation from namespace.client_secret because is None", + True, Mock(client_id="12345678123456781234567812345678", platform_workload_identities=None, client_secret=None), RequiredArgumentMissingError ), ( "should raise RequiredArgumentMissingError when can not create a string representation from namespace.client_secret because it is an empty string", + True, Mock(client_id="12345678123456781234567812345678", platform_workload_identities=None, client_secret=""), RequiredArgumentMissingError ), ( "should not raise any exception when namespace.client_id is a valid input for creating a UUID and namespace.client_secret has a valid str representation", + False, Mock(upgradeable_to=None, client_id="12345678123456781234567812345678", platform_workload_identities=None, client_secret="12345"), None ) @@ -120,16 +127,17 @@ def test_validate_cidr(test_description, dummyclass, attribute_to_get_from_objec @pytest.mark.parametrize( - "test_description, namespace, expected_exception", + "test_description, isCreate, namespace, expected_exception", test_validate_client_id_data, ids=[i[0] for i in test_validate_client_id_data] ) -def test_validate_client_id(test_description, namespace, expected_exception): +def test_validate_client_id(test_description, isCreate, namespace, expected_exception): + validate_client_id_fn = validate_client_id(isCreate) if expected_exception is None: - validate_client_id(namespace) + validate_client_id_fn(namespace) else: with pytest.raises(expected_exception): - validate_client_id(namespace) + validate_client_id_fn(namespace) test_validate_client_secret_data = [ @@ -181,12 +189,6 @@ def test_validate_client_id(test_description, namespace, expected_exception): Mock(upgradeable_to=None, client_secret="123", platform_workload_identities=None), None ), - ( - "should raise MutuallyExclusiveArgumentError exception when isCreate is true and upgradeable_to, client_id and client_secret are present", - True, - Mock(upgradeable_to="4.14.2", client_id="12345678123456781234567812345678", client_secret="123", platform_workload_identities=None), - MutuallyExclusiveArgumentError - ), ( "should raise MutuallyExclusiveArgumentError exception when isCreate is false and upgradeable_to, client_id and client_secret are present", False,