From c434c6b98911d71f27910d7d30cf966462157612 Mon Sep 17 00:00:00 2001 From: Maksim Date: Wed, 24 Apr 2024 14:25:39 +0200 Subject: [PATCH] Create `CloudComposerRunAirflowCLICommandOperator` operator (#38965) --- .../google/cloud/hooks/cloud_composer.py | 252 +++++++++++++++++- .../google/cloud/operators/cloud_composer.py | 149 ++++++++++- .../google/cloud/triggers/cloud_composer.py | 68 +++++ .../operators/cloud/cloud_composer.rst | 20 ++ tests/always/test_project_structure.py | 1 - .../google/cloud/hooks/test_cloud_composer.py | 128 +++++++++ .../cloud/operators/test_cloud_composer.py | 66 ++++- .../cloud/triggers/test_cloud_composer.py | 96 +++++++ .../cloud/composer/example_cloud_composer.py | 25 ++ 9 files changed, 799 insertions(+), 6 deletions(-) create mode 100644 tests/providers/google/cloud/triggers/test_cloud_composer.py diff --git a/airflow/providers/google/cloud/hooks/cloud_composer.py b/airflow/providers/google/cloud/hooks/cloud_composer.py index b5421e6b2b2f7..4ee41c9ffd626 100644 --- a/airflow/providers/google/cloud/hooks/cloud_composer.py +++ b/airflow/providers/google/cloud/hooks/cloud_composer.py @@ -17,7 +17,9 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Sequence +import asyncio +import time +from typing import TYPE_CHECKING, MutableSequence, Sequence from google.api_core.client_options import ClientOptions from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault @@ -25,6 +27,7 @@ EnvironmentsAsyncClient, EnvironmentsClient, ImageVersionsClient, + PollAirflowCommandResponse, ) from airflow.exceptions import AirflowException @@ -42,7 +45,10 @@ from google.cloud.orchestration.airflow.service_v1.services.image_versions.pagers import ( ListImageVersionsPager, ) - from google.cloud.orchestration.airflow.service_v1.types import Environment + from google.cloud.orchestration.airflow.service_v1.types import ( + Environment, + ExecuteAirflowCommandResponse, + ) from google.protobuf.field_mask_pb2 import FieldMask @@ -294,6 +300,127 @@ def list_image_versions( ) return result + @GoogleBaseHook.fallback_to_default_project_id + def execute_airflow_command( + self, + project_id: str, + region: str, + environment_id: str, + command: str, + subcommand: str, + parameters: MutableSequence[str], + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> ExecuteAirflowCommandResponse: + """ + Execute Airflow command for provided Composer environment. + + :param project_id: The ID of the Google Cloud project that the service belongs to. + :param region: The ID of the Google Cloud region that the service belongs to. + :param environment_id: The ID of the Google Cloud environment that the service belongs to. + :param command: Airflow command. + :param subcommand: Airflow subcommand. + :param parameters: Parameters for the Airflow command/subcommand as an array of arguments. It may + contain positional arguments like ``["my-dag-id"]``, key-value parameters like ``["--foo=bar"]`` + or ``["--foo","bar"]``, or other flags like ``["-f"]``. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. 
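For orientation, a minimal sketch of calling this new hook method directly; the project, region, and environment IDs are placeholders, not values from this patch:

```python
from airflow.providers.google.cloud.hooks.cloud_composer import CloudComposerHook

hook = CloudComposerHook(gcp_conn_id="google_cloud_default")
# "airflow dags list -o json" maps onto command/subcommand/parameters:
response = hook.execute_airflow_command(
    project_id="my-project",            # assumed project
    region="us-central1",               # assumed region
    environment_id="my-environment",    # assumed environment
    command="dags",
    subcommand="list",
    parameters=["-o", "json"],
)
# The response identifies where the command runs; these fields are the
# coordinates needed to poll for output later.
print(response.execution_id, response.pod, response.pod_namespace)
```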
+ """ + client = self.get_environment_client() + result = client.execute_airflow_command( + request={ + "environment": self.get_environment_name(project_id, region, environment_id), + "command": command, + "subcommand": subcommand, + "parameters": parameters, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def poll_airflow_command( + self, + project_id: str, + region: str, + environment_id: str, + execution_id: str, + pod: str, + pod_namespace: str, + next_line_number: int, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> PollAirflowCommandResponse: + """ + Poll Airflow command execution result for provided Composer environment. + + :param project_id: The ID of the Google Cloud project that the service belongs to. + :param region: The ID of the Google Cloud region that the service belongs to. + :param environment_id: The ID of the Google Cloud environment that the service belongs to. + :param execution_id: The unique ID of the command execution. + :param pod: The name of the pod where the command is executed. + :param pod_namespace: The namespace of the pod where the command is executed. + :param next_line_number: Line number from which new logs should be fetched. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + """ + client = self.get_environment_client() + result = client.poll_airflow_command( + request={ + "environment": self.get_environment_name(project_id, region, environment_id), + "execution_id": execution_id, + "pod": pod, + "pod_namespace": pod_namespace, + "next_line_number": next_line_number, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + def wait_command_execution_result( + self, + project_id: str, + region: str, + environment_id: str, + execution_cmd_info: dict, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + poll_interval: int = 10, + ) -> dict: + while True: + try: + result = self.poll_airflow_command( + project_id=project_id, + region=region, + environment_id=environment_id, + execution_id=execution_cmd_info["execution_id"], + pod=execution_cmd_info["pod"], + pod_namespace=execution_cmd_info["pod_namespace"], + next_line_number=1, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + except Exception as ex: + self.log.exception("Exception occurred while polling CMD result") + raise AirflowException(ex) + + result_dict = PollAirflowCommandResponse.to_dict(result) + if result_dict["output_end"]: + return result_dict + + self.log.info("Waiting for result...") + time.sleep(poll_interval) + class CloudComposerAsyncHook(GoogleBaseHook): """Hook for Google Cloud Composer async APIs.""" @@ -421,3 +548,124 @@ async def update_environment( timeout=timeout, metadata=metadata, ) + + @GoogleBaseHook.fallback_to_default_project_id + async def execute_airflow_command( + self, + project_id: str, + region: str, + environment_id: str, + command: str, + subcommand: str, + parameters: MutableSequence[str], + retry: AsyncRetry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> AsyncOperation: + """ + Execute Airflow command for provided Composer environment. 
+ + :param project_id: The ID of the Google Cloud project that the service belongs to. + :param region: The ID of the Google Cloud region that the service belongs to. + :param environment_id: The ID of the Google Cloud environment that the service belongs to. + :param command: Airflow command. + :param subcommand: Airflow subcommand. + :param parameters: Parameters for the Airflow command/subcommand as an array of arguments. It may + contain positional arguments like ``["my-dag-id"]``, key-value parameters like ``["--foo=bar"]`` + or ``["--foo","bar"]``, or other flags like ``["-f"]``. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + """ + client = self.get_environment_client() + + return await client.execute_airflow_command( + request={ + "environment": self.get_environment_name(project_id, region, environment_id), + "command": command, + "subcommand": subcommand, + "parameters": parameters, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.fallback_to_default_project_id + async def poll_airflow_command( + self, + project_id: str, + region: str, + environment_id: str, + execution_id: str, + pod: str, + pod_namespace: str, + next_line_number: int, + retry: AsyncRetry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> AsyncOperation: + """ + Poll Airflow command execution result for provided Composer environment. + + :param project_id: The ID of the Google Cloud project that the service belongs to. + :param region: The ID of the Google Cloud region that the service belongs to. + :param environment_id: The ID of the Google Cloud environment that the service belongs to. + :param execution_id: The unique ID of the command execution. + :param pod: The name of the pod where the command is executed. + :param pod_namespace: The namespace of the pod where the command is executed. + :param next_line_number: Line number from which new logs should be fetched. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. 
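Since the async variants mirror the sync signatures, they can also be driven with `asyncio` outside a trigger, e.g. for local experimentation. A sketch in which all IDs and the execution info are made-up placeholders:

```python
import asyncio

from airflow.providers.google.cloud.hooks.cloud_composer import CloudComposerAsyncHook


async def main() -> None:
    hook = CloudComposerAsyncHook(gcp_conn_id="google_cloud_default")
    result = await hook.wait_command_execution_result(
        project_id="my-project",
        region="us-central1",
        environment_id="my-environment",
        execution_cmd_info={
            # Normally taken from ExecuteAirflowCommandResponse; assumed here.
            "execution_id": "example-execution-id",
            "pod": "airflow-worker-abc",
            "pod_namespace": "example-namespace",
        },
        poll_interval=5,
    )
    print(result["exit_info"])


asyncio.run(main())
```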
+ """ + client = self.get_environment_client() + + return await client.poll_airflow_command( + request={ + "environment": self.get_environment_name(project_id, region, environment_id), + "execution_id": execution_id, + "pod": pod, + "pod_namespace": pod_namespace, + "next_line_number": next_line_number, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + async def wait_command_execution_result( + self, + project_id: str, + region: str, + environment_id: str, + execution_cmd_info: dict, + retry: AsyncRetry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + poll_interval: int = 10, + ) -> dict: + while True: + try: + result = await self.poll_airflow_command( + project_id=project_id, + region=region, + environment_id=environment_id, + execution_id=execution_cmd_info["execution_id"], + pod=execution_cmd_info["pod"], + pod_namespace=execution_cmd_info["pod_namespace"], + next_line_number=1, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + except Exception as ex: + self.log.exception("Exception occurred while polling CMD result") + raise AirflowException(ex) + + result_dict = PollAirflowCommandResponse.to_dict(result) + if result_dict["output_end"]: + return result_dict + + self.log.info("Sleeping for %s seconds.", poll_interval) + await asyncio.sleep(poll_interval) diff --git a/airflow/providers/google/cloud/operators/cloud_composer.py b/airflow/providers/google/cloud/operators/cloud_composer.py index de6d49d6566cc..ad63397a176b1 100644 --- a/airflow/providers/google/cloud/operators/cloud_composer.py +++ b/airflow/providers/google/cloud/operators/cloud_composer.py @@ -17,19 +17,23 @@ # under the License. from __future__ import annotations +import shlex from typing import TYPE_CHECKING, Sequence from google.api_core.exceptions import AlreadyExists from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.cloud.orchestration.airflow.service_v1 import ImageVersion -from google.cloud.orchestration.airflow.service_v1.types import Environment +from google.cloud.orchestration.airflow.service_v1.types import Environment, ExecuteAirflowCommandResponse from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.cloud_composer import CloudComposerHook from airflow.providers.google.cloud.links.base import BaseGoogleLink from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator -from airflow.providers.google.cloud.triggers.cloud_composer import CloudComposerExecutionTrigger +from airflow.providers.google.cloud.triggers.cloud_composer import ( + CloudComposerAirflowCLICommandTrigger, + CloudComposerExecutionTrigger, +) from airflow.providers.google.common.consts import GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME if TYPE_CHECKING: @@ -651,3 +655,144 @@ def execute(self, context: Context): metadata=self.metadata, ) return [ImageVersion.to_dict(image) for image in result] + + +class CloudComposerRunAirflowCLICommandOperator(GoogleCloudBaseOperator): + """ + Run Airflow command for provided Composer environment. + + :param project_id: The ID of the Google Cloud project that the service belongs to. + :param region: The ID of the Google Cloud region that the service belongs to. + :param environment_id: The ID of the Google Cloud environment that the service belongs to. + :param command: Airflow command. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. 
+    :param metadata: Strings which should be sent along with the request as metadata.
+    :param gcp_conn_id: The connection ID used to connect to Google Cloud Platform.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :param deferrable: Run the operator in deferrable mode.
+    :param poll_interval: Optional: Control the rate of polling for the result of a deferrable run.
+        By default, the trigger polls every 10 seconds.
+    """
+
+    template_fields = (
+        "project_id",
+        "region",
+        "environment_id",
+        "command",
+        "impersonation_chain",
+    )
+
+    def __init__(
+        self,
+        *,
+        project_id: str,
+        region: str,
+        environment_id: str,
+        command: str,
+        retry: Retry | _MethodDefault = DEFAULT,
+        timeout: float | None = None,
+        metadata: Sequence[tuple[str, str]] = (),
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        poll_interval: int = 10,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.project_id = project_id
+        self.region = region
+        self.environment_id = environment_id
+        self.command = command
+        self.retry = retry
+        self.timeout = timeout
+        self.metadata = metadata
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+        self.deferrable = deferrable
+        self.poll_interval = poll_interval
+
+    def execute(self, context: Context):
+        hook = CloudComposerHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+
+        self.log.info("Executing the command: [ airflow %s ]...", self.command)
+
+        cmd, subcommand, parameters = self._parse_cmd_to_args(self.command)
+        execution_cmd_info = hook.execute_airflow_command(
+            project_id=self.project_id,
+            region=self.region,
+            environment_id=self.environment_id,
+            command=cmd,
+            subcommand=subcommand,
+            parameters=parameters,
+            retry=self.retry,
+            timeout=self.timeout,
+            metadata=self.metadata,
+        )
+        execution_cmd_info_dict = ExecuteAirflowCommandResponse.to_dict(execution_cmd_info)
+
+        self.log.info("Command has been started. execution_id=%s", execution_cmd_info_dict["execution_id"])
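The `_parse_cmd_to_args` helper below relies on `shlex.split`, so shell-style quoting in the templated command string is preserved. A small illustration, using the same example command as this patch's tests:

```python
import shlex

parts = shlex.split("dags list -o json --verbose")
command, subcommand, parameters = parts[0], parts[1], parts[2:]
assert (command, subcommand, parameters) == ("dags", "list", ["-o", "json", "--verbose"])

# Quoting survives, which matters for arguments containing spaces:
assert shlex.split("dags trigger my_dag --conf '{\"x\": 1}'")[-1] == '{"x": 1}'
```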
execution_id=%s", execution_cmd_info_dict["execution_id"]) + + if self.deferrable: + self.defer( + trigger=CloudComposerAirflowCLICommandTrigger( + project_id=self.project_id, + region=self.region, + environment_id=self.environment_id, + execution_cmd_info=execution_cmd_info_dict, + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + poll_interval=self.poll_interval, + ), + method_name=GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME, + ) + return + + result = hook.wait_command_execution_result( + project_id=self.project_id, + region=self.region, + environment_id=self.environment_id, + execution_cmd_info=execution_cmd_info_dict, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + poll_interval=self.poll_interval, + ) + result_str = self._merge_cmd_output_result(result) + self.log.info("Command execution result:\n%s", result_str) + return result + + def execute_complete(self, context: Context, event: dict) -> dict: + if event and event["status"] == "error": + raise AirflowException(event["message"]) + result: dict = event["result"] + result_str = self._merge_cmd_output_result(result) + self.log.info("Command execution result:\n%s", result_str) + return result + + def _parse_cmd_to_args(self, cmd: str) -> tuple: + """Parse user command to command, subcommand and parameters.""" + cmd_dict = shlex.split(cmd) + if not cmd_dict: + raise AirflowException("The provided command is empty.") + + command = cmd_dict[0] if len(cmd_dict) >= 1 else None + subcommand = cmd_dict[1] if len(cmd_dict) >= 2 else None + parameters = cmd_dict[2:] if len(cmd_dict) >= 3 else None + + return command, subcommand, parameters + + def _merge_cmd_output_result(self, result) -> str: + """Merge output to one string.""" + result_str = "\n".join(line_dict["content"] for line_dict in result["output"]) + return result_str diff --git a/airflow/providers/google/cloud/triggers/cloud_composer.py b/airflow/providers/google/cloud/triggers/cloud_composer.py index 4e52783a3791a..ac5a00c60f4a1 100644 --- a/airflow/providers/google/cloud/triggers/cloud_composer.py +++ b/airflow/providers/google/cloud/triggers/cloud_composer.py @@ -78,3 +78,71 @@ async def run(self): "operation_done": operation.done, } ) + + +class CloudComposerAirflowCLICommandTrigger(BaseTrigger): + """The trigger wait for the Airflow CLI command result.""" + + def __init__( + self, + project_id: str, + region: str, + environment_id: str, + execution_cmd_info: dict, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + poll_interval: int = 10, + ): + super().__init__() + self.project_id = project_id + self.region = region + self.environment_id = environment_id + self.execution_cmd_info = execution_cmd_info + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + self.poll_interval = poll_interval + + self.gcp_hook = CloudComposerAsyncHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + + def serialize(self) -> tuple[str, dict[str, Any]]: + return ( + "airflow.providers.google.cloud.triggers.cloud_composer.CloudComposerAirflowCLICommandTrigger", + { + "project_id": self.project_id, + "region": self.region, + "environment_id": self.environment_id, + "execution_cmd_info": self.execution_cmd_info, + "gcp_conn_id": self.gcp_conn_id, + "impersonation_chain": self.impersonation_chain, + "poll_interval": self.poll_interval, + }, + ) + + async def run(self): + try: + result = await self.gcp_hook.wait_command_execution_result( 
diff --git a/docs/apache-airflow-providers-google/operators/cloud/cloud_composer.rst b/docs/apache-airflow-providers-google/operators/cloud/cloud_composer.rst
index b55063c8f6a31..cdb9cb2931325 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/cloud_composer.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/cloud_composer.rst
@@ -157,3 +157,23 @@ You can also list all supported Cloud Composer images:
     :dedent: 4
     :start-after: [START howto_operator_composer_image_list]
     :end-before: [END howto_operator_composer_image_list]
+
+Run Airflow CLI commands
+------------------------
+
+To run Airflow CLI commands in your environments, use:
+:class:`~airflow.providers.google.cloud.operators.cloud_composer.CloudComposerRunAirflowCLICommandOperator`
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/composer/example_cloud_composer.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_run_airflow_cli_command]
+    :end-before: [END howto_operator_run_airflow_cli_command]
+
+or you can define the same operator in deferrable mode:
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/composer/example_cloud_composer.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_run_airflow_cli_command_deferrable_mode]
+    :end-before: [END howto_operator_run_airflow_cli_command_deferrable_mode]
diff --git a/tests/always/test_project_structure.py b/tests/always/test_project_structure.py
index 3437092e6569f..75d824732bdd8 100644
--- a/tests/always/test_project_structure.py
+++ b/tests/always/test_project_structure.py
@@ -147,7 +147,6 @@ def test_providers_modules_should_have_tests(self):
             "tests/providers/google/cloud/transfers/test_mssql_to_gcs.py",
             "tests/providers/google/cloud/transfers/test_presto_to_gcs.py",
             "tests/providers/google/cloud/transfers/test_trino_to_gcs.py",
-            "tests/providers/google/cloud/triggers/test_cloud_composer.py",
             "tests/providers/google/cloud/utils/test_bigquery.py",
             "tests/providers/google/cloud/utils/test_bigquery_get_data.py",
             "tests/providers/google/cloud/utils/test_dataform.py",
diff --git a/tests/providers/google/cloud/hooks/test_cloud_composer.py b/tests/providers/google/cloud/hooks/test_cloud_composer.py
index 1072b310b2192..945a637605520 100644
--- a/tests/providers/google/cloud/hooks/test_cloud_composer.py
+++ b/tests/providers/google/cloud/hooks/test_cloud_composer.py
@@ -37,6 +37,12 @@
         "software_config": {"image_version": "composer-1.17.7-airflow-2.1.4"},
     },
 }
+TEST_COMMAND = "dags"
+TEST_SUBCOMMAND = "list"
+TEST_PARAMETERS = ["--verbose", "-o", "json"]
+TEST_EXECUTION_ID = "test-execution-id"
+TEST_POD = "test-pod"
+TEST_POD_NAMESPACE = "test-namespace"
 TEST_UPDATE_MASK = {"paths": ["labels.label1"]}
 TEST_UPDATED_ENVIRONMENT = {
@@ -197,6 +203,64 @@ def test_list_image_versions(self, mock_client) -> None:
             metadata=TEST_METADATA,
         )
+    @mock.patch(COMPOSER_STRING.format("CloudComposerHook.get_environment_client"))
+    def test_execute_airflow_command(self, mock_client) -> None:
+        self.hook.execute_airflow_command(
+            project_id=TEST_GCP_PROJECT,
+            region=TEST_GCP_REGION,
+            environment_id=TEST_ENVIRONMENT_ID,
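The hook tests added here all share one pattern: patch the GAPIC client factory, call the hook, then assert on the request dict the hook assembled. A condensed, hedged sketch of that pattern (placeholder IDs, assuming the provider package is importable):

```python
from unittest import mock

from airflow.providers.google.cloud.hooks.cloud_composer import CloudComposerHook

hook = CloudComposerHook(gcp_conn_id="google_cloud_default")
with mock.patch.object(CloudComposerHook, "get_environment_client") as mock_client:
    hook.execute_airflow_command(
        project_id="my-project",
        region="us-central1",
        environment_id="my-environment",
        command="dags",
        subcommand="list",
        parameters=[],
    )
# The stubbed client records the call, so the request can be asserted without
# touching any real Google Cloud API.
mock_client.return_value.execute_airflow_command.assert_called_once()
```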
command=TEST_COMMAND, + subcommand=TEST_SUBCOMMAND, + parameters=TEST_PARAMETERS, + retry=TEST_RETRY, + timeout=TEST_TIMEOUT, + metadata=TEST_METADATA, + ) + mock_client.assert_called_once() + mock_client.return_value.execute_airflow_command.assert_called_once_with( + request={ + "environment": self.hook.get_environment_name( + TEST_GCP_PROJECT, TEST_GCP_REGION, TEST_ENVIRONMENT_ID + ), + "command": TEST_COMMAND, + "subcommand": TEST_SUBCOMMAND, + "parameters": TEST_PARAMETERS, + }, + retry=TEST_RETRY, + timeout=TEST_TIMEOUT, + metadata=TEST_METADATA, + ) + + @mock.patch(COMPOSER_STRING.format("CloudComposerHook.get_environment_client")) + def test_poll_airflow_command(self, mock_client) -> None: + self.hook.poll_airflow_command( + project_id=TEST_GCP_PROJECT, + region=TEST_GCP_REGION, + environment_id=TEST_ENVIRONMENT_ID, + execution_id=TEST_EXECUTION_ID, + pod=TEST_POD, + pod_namespace=TEST_POD_NAMESPACE, + next_line_number=1, + retry=TEST_RETRY, + timeout=TEST_TIMEOUT, + metadata=TEST_METADATA, + ) + mock_client.assert_called_once() + mock_client.return_value.poll_airflow_command.assert_called_once_with( + request={ + "environment": self.hook.get_environment_name( + TEST_GCP_PROJECT, TEST_GCP_REGION, TEST_ENVIRONMENT_ID + ), + "execution_id": TEST_EXECUTION_ID, + "pod": TEST_POD, + "pod_namespace": TEST_POD_NAMESPACE, + "next_line_number": 1, + }, + retry=TEST_RETRY, + timeout=TEST_TIMEOUT, + metadata=TEST_METADATA, + ) + class TestCloudComposerAsyncHook: def test_delegate_to_runtime_error(self): @@ -282,3 +346,67 @@ async def test_update_environment(self, mock_client) -> None: timeout=TEST_TIMEOUT, metadata=TEST_METADATA, ) + + @pytest.mark.asyncio + @mock.patch(COMPOSER_STRING.format("CloudComposerAsyncHook.get_environment_client")) + async def test_execute_airflow_command(self, mock_client) -> None: + mock_env_client = AsyncMock(EnvironmentsAsyncClient) + mock_client.return_value = mock_env_client + await self.hook.execute_airflow_command( + project_id=TEST_GCP_PROJECT, + region=TEST_GCP_REGION, + environment_id=TEST_ENVIRONMENT_ID, + command=TEST_COMMAND, + subcommand=TEST_SUBCOMMAND, + parameters=TEST_PARAMETERS, + retry=TEST_RETRY, + timeout=TEST_TIMEOUT, + metadata=TEST_METADATA, + ) + mock_client.assert_called_once() + mock_client.return_value.execute_airflow_command.assert_called_once_with( + request={ + "environment": self.hook.get_environment_name( + TEST_GCP_PROJECT, TEST_GCP_REGION, TEST_ENVIRONMENT_ID + ), + "command": TEST_COMMAND, + "subcommand": TEST_SUBCOMMAND, + "parameters": TEST_PARAMETERS, + }, + retry=TEST_RETRY, + timeout=TEST_TIMEOUT, + metadata=TEST_METADATA, + ) + + @pytest.mark.asyncio + @mock.patch(COMPOSER_STRING.format("CloudComposerAsyncHook.get_environment_client")) + async def test_poll_airflow_command(self, mock_client) -> None: + mock_env_client = AsyncMock(EnvironmentsAsyncClient) + mock_client.return_value = mock_env_client + await self.hook.poll_airflow_command( + project_id=TEST_GCP_PROJECT, + region=TEST_GCP_REGION, + environment_id=TEST_ENVIRONMENT_ID, + execution_id=TEST_EXECUTION_ID, + pod=TEST_POD, + pod_namespace=TEST_POD_NAMESPACE, + next_line_number=1, + retry=TEST_RETRY, + timeout=TEST_TIMEOUT, + metadata=TEST_METADATA, + ) + mock_client.assert_called_once() + mock_client.return_value.poll_airflow_command.assert_called_once_with( + request={ + "environment": self.hook.get_environment_name( + TEST_GCP_PROJECT, TEST_GCP_REGION, TEST_ENVIRONMENT_ID + ), + "execution_id": TEST_EXECUTION_ID, + "pod": TEST_POD, + "pod_namespace": 
TEST_POD_NAMESPACE, + "next_line_number": 1, + }, + retry=TEST_RETRY, + timeout=TEST_TIMEOUT, + metadata=TEST_METADATA, + ) diff --git a/tests/providers/google/cloud/operators/test_cloud_composer.py b/tests/providers/google/cloud/operators/test_cloud_composer.py index 7bda6fce5a24b..e882db9526a01 100644 --- a/tests/providers/google/cloud/operators/test_cloud_composer.py +++ b/tests/providers/google/cloud/operators/test_cloud_composer.py @@ -28,9 +28,13 @@ CloudComposerGetEnvironmentOperator, CloudComposerListEnvironmentsOperator, CloudComposerListImageVersionsOperator, + CloudComposerRunAirflowCLICommandOperator, CloudComposerUpdateEnvironmentOperator, ) -from airflow.providers.google.cloud.triggers.cloud_composer import CloudComposerExecutionTrigger +from airflow.providers.google.cloud.triggers.cloud_composer import ( + CloudComposerAirflowCLICommandTrigger, + CloudComposerExecutionTrigger, +) from airflow.providers.google.common.consts import GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME TASK_ID = "task-id" @@ -46,6 +50,10 @@ "software_config": {"image_version": "composer-1.17.7-airflow-2.1.4"}, }, } +TEST_USER_COMMAND = "dags list -o json --verbose" +TEST_COMMAND = "dags" +TEST_SUBCOMMAND = "list" +TEST_PARAMETERS = ["-o", "json", "--verbose"] TEST_UPDATE_MASK = {"paths": ["labels.label1"]} TEST_UPDATED_ENVIRONMENT = { @@ -305,3 +313,59 @@ def test_assert_valid_hook_call(self, mock_hook) -> None: timeout=TEST_TIMEOUT, metadata=TEST_METADATA, ) + + +class TestCloudComposerRunAirflowCLICommandOperator: + @mock.patch(COMPOSER_STRING.format("ExecuteAirflowCommandResponse.to_dict")) + @mock.patch(COMPOSER_STRING.format("CloudComposerHook")) + def test_execute(self, mock_hook, to_dict_mode) -> None: + op = CloudComposerRunAirflowCLICommandOperator( + task_id=TASK_ID, + project_id=TEST_GCP_PROJECT, + region=TEST_GCP_REGION, + environment_id=TEST_ENVIRONMENT_ID, + command=TEST_USER_COMMAND, + gcp_conn_id=TEST_GCP_CONN_ID, + retry=TEST_RETRY, + timeout=TEST_TIMEOUT, + metadata=TEST_METADATA, + ) + op.execute(mock.MagicMock()) + mock_hook.assert_called_once_with( + gcp_conn_id=TEST_GCP_CONN_ID, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + ) + mock_hook.return_value.execute_airflow_command.assert_called_once_with( + project_id=TEST_GCP_PROJECT, + region=TEST_GCP_REGION, + environment_id=TEST_ENVIRONMENT_ID, + command=TEST_COMMAND, + subcommand=TEST_SUBCOMMAND, + parameters=TEST_PARAMETERS, + retry=TEST_RETRY, + timeout=TEST_TIMEOUT, + metadata=TEST_METADATA, + ) + + @mock.patch(COMPOSER_STRING.format("ExecuteAirflowCommandResponse.to_dict")) + @mock.patch(COMPOSER_STRING.format("CloudComposerHook")) + @mock.patch(COMPOSER_TRIGGERS_STRING.format("CloudComposerAsyncHook")) + def test_execute_deferrable(self, mock_trigger_hook, mock_hook, to_dict_mode): + op = CloudComposerRunAirflowCLICommandOperator( + task_id=TASK_ID, + project_id=TEST_GCP_PROJECT, + region=TEST_GCP_REGION, + environment_id=TEST_ENVIRONMENT_ID, + command=TEST_USER_COMMAND, + gcp_conn_id=TEST_GCP_CONN_ID, + retry=TEST_RETRY, + timeout=TEST_TIMEOUT, + metadata=TEST_METADATA, + deferrable=True, + ) + + with pytest.raises(TaskDeferred) as exc: + op.execute(mock.MagicMock()) + + assert isinstance(exc.value.trigger, CloudComposerAirflowCLICommandTrigger) + assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME diff --git a/tests/providers/google/cloud/triggers/test_cloud_composer.py b/tests/providers/google/cloud/triggers/test_cloud_composer.py new file mode 100644 index 0000000000000..99daaf83bdc26 --- /dev/null +++ 
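The async-hook tests above lean on `unittest.mock.AsyncMock`; a self-contained sketch of why that works: the stubbed method is awaitable, and awaited calls can be asserted afterwards.

```python
import asyncio
from unittest import mock

client = mock.AsyncMock()
client.poll_airflow_command.return_value = {"output_end": True}


async def poll_once():
    # Awaiting the AsyncMock attribute returns its configured return_value.
    return await client.poll_airflow_command(next_line_number=1)


assert asyncio.run(poll_once()) == {"output_end": True}
client.poll_airflow_command.assert_awaited_once_with(next_line_number=1)
```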
b/tests/providers/google/cloud/triggers/test_cloud_composer.py @@ -0,0 +1,96 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from unittest import mock + +import pytest + +from airflow.models import Connection +from airflow.providers.google.cloud.triggers.cloud_composer import CloudComposerAirflowCLICommandTrigger +from airflow.triggers.base import TriggerEvent + +TEST_PROJECT_ID = "test-project-id" +TEST_LOCATION = "us-central1" +TEST_ENVIRONMENT_ID = "testenvname" +TEST_EXEC_CMD_INFO = { + "execution_id": "test_id", + "pod": "test_pod", + "pod_namespace": "test_namespace", + "error": "test_error", +} +TEST_GCP_CONN_ID = "test_gcp_conn_id" +TEST_POLL_INTERVAL = 10 +TEST_IMPERSONATION_CHAIN = "test_impersonation_chain" +TEST_EXEC_RESULT = { + "output": [{"line_number": 1, "content": "test_content"}], + "output_end": True, + "exit_info": {"exit_code": 0, "error": ""}, +} + + +@pytest.fixture +@mock.patch( + "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_connection", + return_value=Connection(conn_id="test_conn"), +) +def trigger(mock_conn): + return CloudComposerAirflowCLICommandTrigger( + project_id=TEST_PROJECT_ID, + region=TEST_LOCATION, + environment_id=TEST_ENVIRONMENT_ID, + execution_cmd_info=TEST_EXEC_CMD_INFO, + gcp_conn_id=TEST_GCP_CONN_ID, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + poll_interval=TEST_POLL_INTERVAL, + ) + + +class TestCloudComposerAirflowCLICommandTrigger: + def test_serialize(self, trigger): + actual_data = trigger.serialize() + expected_data = ( + "airflow.providers.google.cloud.triggers.cloud_composer.CloudComposerAirflowCLICommandTrigger", + { + "project_id": TEST_PROJECT_ID, + "region": TEST_LOCATION, + "environment_id": TEST_ENVIRONMENT_ID, + "execution_cmd_info": TEST_EXEC_CMD_INFO, + "gcp_conn_id": TEST_GCP_CONN_ID, + "impersonation_chain": TEST_IMPERSONATION_CHAIN, + "poll_interval": TEST_POLL_INTERVAL, + }, + ) + assert actual_data == expected_data + + @pytest.mark.asyncio + @mock.patch( + "airflow.providers.google.cloud.hooks.cloud_composer.CloudComposerAsyncHook.wait_command_execution_result" + ) + async def test_run(self, mock_exec_result, trigger): + mock_exec_result.return_value = TEST_EXEC_RESULT + + expected_event = TriggerEvent( + { + "status": "success", + "result": TEST_EXEC_RESULT, + } + ) + actual_event = await trigger.run().asend(None) + + assert actual_event == expected_event diff --git a/tests/system/providers/google/cloud/composer/example_cloud_composer.py b/tests/system/providers/google/cloud/composer/example_cloud_composer.py index fb8958412e637..fe60c56ddf812 100644 --- a/tests/system/providers/google/cloud/composer/example_cloud_composer.py +++ b/tests/system/providers/google/cloud/composer/example_cloud_composer.py @@ -28,6 +28,7 
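The trigger test above pulls the first event with `trigger.run().asend(None)`: since `run()` is an async generator rather than a coroutine, its first `TriggerEvent` is retrieved by sending `None` into the generator. A minimal illustration:

```python
import asyncio


async def events():
    # Stands in for BaseTrigger.run(), which yields TriggerEvent objects.
    yield {"status": "success"}


async def first_event():
    # The first asend() on an async generator must send None.
    return await events().asend(None)


assert asyncio.run(first_event()) == {"status": "success"}
```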
@@ CloudComposerGetEnvironmentOperator, CloudComposerListEnvironmentsOperator, CloudComposerListImageVersionsOperator, + CloudComposerRunAirflowCLICommandOperator, CloudComposerUpdateEnvironmentOperator, ) from airflow.utils.trigger_rule import TriggerRule @@ -59,6 +60,8 @@ UPDATE_MASK = {"paths": ["labels.label"]} # [END howto_operator_composer_update_environment] +COMMAND = "dags list -o json --verbose" + with DAG( DAG_ID, @@ -134,6 +137,27 @@ ) # [END howto_operator_update_composer_environment_deferrable_mode] + # [START howto_operator_run_airflow_cli_command] + run_airflow_cli_cmd = CloudComposerRunAirflowCLICommandOperator( + task_id="run_airflow_cli_cmd", + project_id=PROJECT_ID, + region=REGION, + environment_id=ENVIRONMENT_ID, + command=COMMAND, + ) + # [END howto_operator_run_airflow_cli_command] + + # [START howto_operator_run_airflow_cli_command_deferrable_mode] + defer_run_airflow_cli_cmd = CloudComposerRunAirflowCLICommandOperator( + task_id="defer_run_airflow_cli_cmd", + project_id=PROJECT_ID, + region=REGION, + environment_id=ENVIRONMENT_ID_ASYNC, + command=COMMAND, + deferrable=True, + ) + # [END howto_operator_run_airflow_cli_command_deferrable_mode] + # [START howto_operator_delete_composer_environment] delete_env = CloudComposerDeleteEnvironmentOperator( task_id="delete_env", @@ -161,6 +185,7 @@ list_envs, get_env, [update_env, defer_update_env], + [run_airflow_cli_cmd, defer_run_airflow_cli_cmd], [delete_env, defer_delete_env], )
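For reference, a sketch of the result dict the operator returns (and pushes to XCom) and the flattening performed by `_merge_cmd_output_result`; the shape mirrors `TEST_EXEC_RESULT` in the new trigger tests:

```python
result = {
    "output": [
        {"line_number": 1, "content": "dag_id"},
        {"line_number": 2, "content": "example_dag"},
    ],
    "output_end": True,
    "exit_info": {"exit_code": 0, "error": ""},
}
# Equivalent to the operator's _merge_cmd_output_result helper:
merged = "\n".join(line["content"] for line in result["output"])
assert merged == "dag_id\nexample_dag"
```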