Skip to content

Commit

Permalink
Create CloudComposerRunAirflowCLICommandOperator operator (#38965)
Browse files Browse the repository at this point in the history
  • Loading branch information
MaksYermak authored Apr 24, 2024
1 parent fd1b5f5 commit c434c6b
Show file tree
Hide file tree
Showing 9 changed files with 799 additions and 6 deletions.
252 changes: 250 additions & 2 deletions airflow/providers/google/cloud/hooks/cloud_composer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,17 @@
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, Sequence
import asyncio
import time
from typing import TYPE_CHECKING, MutableSequence, Sequence

from google.api_core.client_options import ClientOptions
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
from google.cloud.orchestration.airflow.service_v1 import (
EnvironmentsAsyncClient,
EnvironmentsClient,
ImageVersionsClient,
PollAirflowCommandResponse,
)

from airflow.exceptions import AirflowException
Expand All @@ -42,7 +45,10 @@
from google.cloud.orchestration.airflow.service_v1.services.image_versions.pagers import (
ListImageVersionsPager,
)
from google.cloud.orchestration.airflow.service_v1.types import Environment
from google.cloud.orchestration.airflow.service_v1.types import (
Environment,
ExecuteAirflowCommandResponse,
)
from google.protobuf.field_mask_pb2 import FieldMask


Expand Down Expand Up @@ -294,6 +300,127 @@ def list_image_versions(
)
return result

@GoogleBaseHook.fallback_to_default_project_id
def execute_airflow_command(
self,
project_id: str,
region: str,
environment_id: str,
command: str,
subcommand: str,
parameters: MutableSequence[str],
retry: Retry | _MethodDefault = DEFAULT,
timeout: float | None = None,
metadata: Sequence[tuple[str, str]] = (),
) -> ExecuteAirflowCommandResponse:
"""
Execute Airflow command for provided Composer environment.
:param project_id: The ID of the Google Cloud project that the service belongs to.
:param region: The ID of the Google Cloud region that the service belongs to.
:param environment_id: The ID of the Google Cloud environment that the service belongs to.
:param command: Airflow command.
:param subcommand: Airflow subcommand.
:param parameters: Parameters for the Airflow command/subcommand as an array of arguments. It may
contain positional arguments like ``["my-dag-id"]``, key-value parameters like ``["--foo=bar"]``
or ``["--foo","bar"]``, or other flags like ``["-f"]``.
:param retry: Designation of what errors, if any, should be retried.
:param timeout: The timeout for this request.
:param metadata: Strings which should be sent along with the request as metadata.
"""
client = self.get_environment_client()
result = client.execute_airflow_command(
request={
"environment": self.get_environment_name(project_id, region, environment_id),
"command": command,
"subcommand": subcommand,
"parameters": parameters,
},
retry=retry,
timeout=timeout,
metadata=metadata,
)
return result

@GoogleBaseHook.fallback_to_default_project_id
def poll_airflow_command(
self,
project_id: str,
region: str,
environment_id: str,
execution_id: str,
pod: str,
pod_namespace: str,
next_line_number: int,
retry: Retry | _MethodDefault = DEFAULT,
timeout: float | None = None,
metadata: Sequence[tuple[str, str]] = (),
) -> PollAirflowCommandResponse:
"""
Poll Airflow command execution result for provided Composer environment.
:param project_id: The ID of the Google Cloud project that the service belongs to.
:param region: The ID of the Google Cloud region that the service belongs to.
:param environment_id: The ID of the Google Cloud environment that the service belongs to.
:param execution_id: The unique ID of the command execution.
:param pod: The name of the pod where the command is executed.
:param pod_namespace: The namespace of the pod where the command is executed.
:param next_line_number: Line number from which new logs should be fetched.
:param retry: Designation of what errors, if any, should be retried.
:param timeout: The timeout for this request.
:param metadata: Strings which should be sent along with the request as metadata.
"""
client = self.get_environment_client()
result = client.poll_airflow_command(
request={
"environment": self.get_environment_name(project_id, region, environment_id),
"execution_id": execution_id,
"pod": pod,
"pod_namespace": pod_namespace,
"next_line_number": next_line_number,
},
retry=retry,
timeout=timeout,
metadata=metadata,
)
return result

def wait_command_execution_result(
self,
project_id: str,
region: str,
environment_id: str,
execution_cmd_info: dict,
retry: Retry | _MethodDefault = DEFAULT,
timeout: float | None = None,
metadata: Sequence[tuple[str, str]] = (),
poll_interval: int = 10,
) -> dict:
while True:
try:
result = self.poll_airflow_command(
project_id=project_id,
region=region,
environment_id=environment_id,
execution_id=execution_cmd_info["execution_id"],
pod=execution_cmd_info["pod"],
pod_namespace=execution_cmd_info["pod_namespace"],
next_line_number=1,
retry=retry,
timeout=timeout,
metadata=metadata,
)
except Exception as ex:
self.log.exception("Exception occurred while polling CMD result")
raise AirflowException(ex)

result_dict = PollAirflowCommandResponse.to_dict(result)
if result_dict["output_end"]:
return result_dict

self.log.info("Waiting for result...")
time.sleep(poll_interval)


class CloudComposerAsyncHook(GoogleBaseHook):
"""Hook for Google Cloud Composer async APIs."""
Expand Down Expand Up @@ -421,3 +548,124 @@ async def update_environment(
timeout=timeout,
metadata=metadata,
)

@GoogleBaseHook.fallback_to_default_project_id
async def execute_airflow_command(
self,
project_id: str,
region: str,
environment_id: str,
command: str,
subcommand: str,
parameters: MutableSequence[str],
retry: AsyncRetry | _MethodDefault = DEFAULT,
timeout: float | None = None,
metadata: Sequence[tuple[str, str]] = (),
) -> AsyncOperation:
"""
Execute Airflow command for provided Composer environment.
:param project_id: The ID of the Google Cloud project that the service belongs to.
:param region: The ID of the Google Cloud region that the service belongs to.
:param environment_id: The ID of the Google Cloud environment that the service belongs to.
:param command: Airflow command.
:param subcommand: Airflow subcommand.
:param parameters: Parameters for the Airflow command/subcommand as an array of arguments. It may
contain positional arguments like ``["my-dag-id"]``, key-value parameters like ``["--foo=bar"]``
or ``["--foo","bar"]``, or other flags like ``["-f"]``.
:param retry: Designation of what errors, if any, should be retried.
:param timeout: The timeout for this request.
:param metadata: Strings which should be sent along with the request as metadata.
"""
client = self.get_environment_client()

return await client.execute_airflow_command(
request={
"environment": self.get_environment_name(project_id, region, environment_id),
"command": command,
"subcommand": subcommand,
"parameters": parameters,
},
retry=retry,
timeout=timeout,
metadata=metadata,
)

@GoogleBaseHook.fallback_to_default_project_id
async def poll_airflow_command(
self,
project_id: str,
region: str,
environment_id: str,
execution_id: str,
pod: str,
pod_namespace: str,
next_line_number: int,
retry: AsyncRetry | _MethodDefault = DEFAULT,
timeout: float | None = None,
metadata: Sequence[tuple[str, str]] = (),
) -> AsyncOperation:
"""
Poll Airflow command execution result for provided Composer environment.
:param project_id: The ID of the Google Cloud project that the service belongs to.
:param region: The ID of the Google Cloud region that the service belongs to.
:param environment_id: The ID of the Google Cloud environment that the service belongs to.
:param execution_id: The unique ID of the command execution.
:param pod: The name of the pod where the command is executed.
:param pod_namespace: The namespace of the pod where the command is executed.
:param next_line_number: Line number from which new logs should be fetched.
:param retry: Designation of what errors, if any, should be retried.
:param timeout: The timeout for this request.
:param metadata: Strings which should be sent along with the request as metadata.
"""
client = self.get_environment_client()

return await client.poll_airflow_command(
request={
"environment": self.get_environment_name(project_id, region, environment_id),
"execution_id": execution_id,
"pod": pod,
"pod_namespace": pod_namespace,
"next_line_number": next_line_number,
},
retry=retry,
timeout=timeout,
metadata=metadata,
)

async def wait_command_execution_result(
self,
project_id: str,
region: str,
environment_id: str,
execution_cmd_info: dict,
retry: AsyncRetry | _MethodDefault = DEFAULT,
timeout: float | None = None,
metadata: Sequence[tuple[str, str]] = (),
poll_interval: int = 10,
) -> dict:
while True:
try:
result = await self.poll_airflow_command(
project_id=project_id,
region=region,
environment_id=environment_id,
execution_id=execution_cmd_info["execution_id"],
pod=execution_cmd_info["pod"],
pod_namespace=execution_cmd_info["pod_namespace"],
next_line_number=1,
retry=retry,
timeout=timeout,
metadata=metadata,
)
except Exception as ex:
self.log.exception("Exception occurred while polling CMD result")
raise AirflowException(ex)

result_dict = PollAirflowCommandResponse.to_dict(result)
if result_dict["output_end"]:
return result_dict

self.log.info("Sleeping for %s seconds.", poll_interval)
await asyncio.sleep(poll_interval)
Loading

0 comments on commit c434c6b

Please sign in to comment.