Add deferrable mode to DataprocInstantiateWorkflowTemplateOperator (#28618)

Co-authored-by: Beata Kossakowska <[email protected]>
bkossakowska and Beata Kossakowska authored Feb 20, 2023
1 parent c5d548b commit 1677d80
Showing 8 changed files with 303 additions and 15 deletions.
8 changes: 8 additions & 0 deletions airflow/providers/google/cloud/hooks/dataproc.py
@@ -28,6 +28,7 @@
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
from google.api_core.operation import Operation
from google.api_core.operation_async import AsyncOperation
from google.api_core.operations_v1.operations_client import OperationsClient
from google.api_core.retry import Retry
from google.cloud.dataproc_v1 import (
Batch,
@@ -1047,6 +1048,10 @@ def get_batch_client(self, region: str | None = None) -> BatchControllerAsyncClient:
credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
)

def get_operations_client(self, region: str) -> OperationsClient:
    """Returns OperationsClient."""
    return self.get_template_client(region=region).transport.operations_client

@GoogleBaseHook.fallback_to_default_project_id
async def create_cluster(
self,
@@ -1459,6 +1464,9 @@ async def instantiate_inline_workflow_template(
)
return operation

async def get_operation(self, region: str, operation_name: str):
    """Return the long-running operation with the given name, via the regional operations client."""
    return await self.get_operations_client(region).get_operation(name=operation_name)

@GoogleBaseHook.fallback_to_default_project_id
async def get_job(
self,
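Taken together, get_operations_client and get_operation give the async hook a polling path for long-running operations. A minimal sketch of how a caller might poll with them — the connection id, region, and operation name below are placeholders, not values from this commit:

import asyncio

from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook


async def wait_for_operation(operation_name: str) -> None:
    # DataprocAsyncHook is the async hook these methods were added to.
    hook = DataprocAsyncHook(gcp_conn_id="google_cloud_default")
    while True:
        # get_operation() resolves the regional OperationsClient internally
        # via get_operations_client() and fetches the operation by name.
        operation = await hook.get_operation(region="us-central1", operation_name=operation_name)
        if operation.done:
            break
        await asyncio.sleep(10)


asyncio.run(wait_for_operation("projects/my-project/regions/us-central1/operations/my-operation-id"))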
42 changes: 38 additions & 4 deletions airflow/providers/google/cloud/operators/dataproc.py
@@ -55,6 +55,7 @@
DataprocClusterTrigger,
DataprocDeleteClusterTrigger,
DataprocSubmitTrigger,
DataprocWorkflowTrigger,
)
from airflow.utils import timezone

@@ -1688,7 +1689,7 @@ class DataprocInstantiateWorkflowTemplateOperator(BaseOperator):
.. seealso::
Please refer to:
https://cloud.google.com/dataproc/docs/reference/rest/v1beta2/projects.regions.workflowTemplates/instantiate
https://cloud.google.com/dataproc/docs/reference/rest/v1/projects.regions.workflowTemplates/instantiate
:param template_id: The id of the template. (templated)
:param project_id: The ID of the google cloud project in which
@@ -1717,6 +1718,8 @@ class DataprocInstantiateWorkflowTemplateOperator(BaseOperator):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param deferrable: Run operator in the deferrable mode.
:param polling_interval_seconds: Time (seconds) to wait between calls to check the run status.
"""

template_fields: Sequence[str] = ("template_id", "impersonation_chain", "request_id", "parameters")
@@ -1737,10 +1740,13 @@ def __init__(
metadata: Sequence[tuple[str, str]] = (),
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
deferrable: bool = False,
polling_interval_seconds: int = 10,
**kwargs,
) -> None:
super().__init__(**kwargs)

if deferrable and polling_interval_seconds <= 0:
raise ValueError("Invalid value for polling_interval_seconds. Expected value greater than 0")
self.template_id = template_id
self.parameters = parameters
self.version = version
@@ -1752,6 +1758,8 @@ def __init__(
self.request_id = request_id
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.deferrable = deferrable
self.polling_interval_seconds = polling_interval_seconds

def execute(self, context: Context):
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
@@ -1772,8 +1780,34 @@ def execute(self, context: Context):
context=context, task_instance=self, url=DATAPROC_WORKFLOW_LINK, resource=self.workflow_id
)
self.log.info("Template instantiated. Workflow Id : %s", self.workflow_id)
operation.result()
self.log.info("Workflow %s completed successfully", self.workflow_id)
if not self.deferrable:
hook.wait_for_operation(timeout=self.timeout, result_retry=self.retry, operation=operation)
self.log.info("Workflow %s completed successfully", self.workflow_id)
else:
self.defer(
trigger=DataprocWorkflowTrigger(
template_name=self.template_id,
name=operation.operation.name,
project_id=self.project_id,
region=self.region,
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
polling_interval_seconds=self.polling_interval_seconds,
),
method_name="execute_complete",
)

def execute_complete(self, context, event=None) -> None:
    """
    Callback for when the trigger fires; returns immediately.
    Relies on the trigger to throw an exception, otherwise it assumes
    execution was successful.
    """
    if event is None:
        raise AirflowException("Trigger returned no event; workflow state is unknown.")
    if event["status"] in ("failed", "error"):
        self.log.exception("Unexpected error in the operation.")
        raise AirflowException(event["message"])

    self.log.info("Workflow %s completed successfully", event["operation_name"])


class DataprocInstantiateInlineWorkflowTemplateOperator(BaseOperator):
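For a sense of how the new parameters surface in a DAG, here is a minimal sketch of a deferrable invocation; the DAG id, project, region, and template id are placeholders rather than values from this change:

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.dataproc import (
    DataprocInstantiateWorkflowTemplateOperator,
)

with DAG("dataproc_workflow_deferrable", start_date=datetime(2023, 1, 1), schedule=None) as dag:
    trigger_workflow = DataprocInstantiateWorkflowTemplateOperator(
        task_id="trigger_workflow_async",
        template_id="my-workflow-template",  # placeholder
        project_id="my-project",  # placeholder
        region="us-central1",  # placeholder
        deferrable=True,  # hand the wait over to the triggerer via DataprocWorkflowTrigger
        polling_interval_seconds=10,  # must be > 0 in deferrable mode, per the __init__ check
    )

With deferrable=True the worker slot is freed once the template is instantiated; execute_complete() then raises on a "failed"/"error" event and logs success otherwise.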
90 changes: 90 additions & 0 deletions airflow/providers/google/cloud/triggers/dataproc.py
@@ -310,3 +310,93 @@ def _get_hook(self) -> DataprocAsyncHook:
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)


class DataprocWorkflowTrigger(BaseTrigger):
"""
Trigger that periodically polls information from Dataproc API to verify status.
Implementation leverages asynchronous transport.
"""

def __init__(
self,
template_name: str,
name: str,
region: str,
project_id: str | None = None,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
delegate_to: str | None = None,
polling_interval_seconds: int = 10,
):
super().__init__()
self.gcp_conn_id = gcp_conn_id
self.template_name = template_name
self.name = name
self.impersonation_chain = impersonation_chain
self.project_id = project_id
self.region = region
self.polling_interval_seconds = polling_interval_seconds
self.delegate_to = delegate_to
if delegate_to:
warnings.warn(
"'delegate_to' parameter is deprecated, please use 'impersonation_chain'", DeprecationWarning
)

def serialize(self):
return (
"airflow.providers.google.cloud.triggers.dataproc.DataprocWorkflowTrigger",
{
"template_name": self.template_name,
"name": self.name,
"project_id": self.project_id,
"region": self.region,
"gcp_conn_id": self.gcp_conn_id,
"delegate_to": self.delegate_to,
"impersonation_chain": self.impersonation_chain,
"polling_interval_seconds": self.polling_interval_seconds,
},
)

async def run(self) -> AsyncIterator["TriggerEvent"]:
hook = self._get_hook()
while True:
try:
operation = await hook.get_operation(region=self.region, operation_name=self.name)
if operation.done:
if operation.error.message:
yield TriggerEvent(
{
"operation_name": operation.name,
"operation_done": operation.done,
"status": "error",
"message": operation.error.message,
}
)
return
yield TriggerEvent(
{
"operation_name": operation.name,
"operation_done": operation.done,
"status": "success",
"message": "Operation is successfully ended.",
}
)
return
else:
self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds)
await asyncio.sleep(self.polling_interval_seconds)
except Exception as e:
self.log.exception("Exception occurred while checking operation status.")
yield TriggerEvent(
{
"status": "failed",
"message": str(e),
}
)

def _get_hook(self) -> DataprocAsyncHook: # type: ignore[override]
return DataprocAsyncHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
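
Since a deferred task is resumed by the separate triggerer process, the trigger must survive a serialize()/reconstruct round trip; every __init__ argument therefore reappears in the serialized kwargs. A small sketch of that contract, with placeholder values:

from airflow.providers.google.cloud.triggers.dataproc import DataprocWorkflowTrigger

trigger = DataprocWorkflowTrigger(
    template_name="my-workflow-template",  # placeholder
    name="projects/my-project/regions/us-central1/operations/my-operation-id",  # placeholder
    region="us-central1",  # placeholder
    project_id="my-project",  # placeholder
)

classpath, kwargs = trigger.serialize()
assert classpath == "airflow.providers.google.cloud.triggers.dataproc.DataprocWorkflowTrigger"
# The triggerer rebuilds the trigger as DataprocWorkflowTrigger(**kwargs),
# so the kwargs must round-trip every constructor argument.
assert DataprocWorkflowTrigger(**kwargs).name == trigger.name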
@@ -262,6 +262,14 @@ Once a workflow is created users can trigger it using
:start-after: [START how_to_cloud_dataproc_trigger_workflow_template]
:end-before: [END how_to_cloud_dataproc_trigger_workflow_template]

You can also run this operator in deferrable mode:

.. exampleinclude:: /../../tests/system/providers/google/cloud/dataproc/example_dataproc_workflow.py
:language: python
:dedent: 4
:start-after: [START how_to_cloud_dataproc_trigger_workflow_template_async]
:end-before: [END how_to_cloud_dataproc_trigger_workflow_template_async]
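
The included snippet lives in the referenced system-test file rather than in this diff; in rough outline, the block between those markers instantiates the operator with ``deferrable=True`` (the ids below are placeholders, not the test's actual constants):

.. code-block:: python

    trigger_workflow_async = DataprocInstantiateWorkflowTemplateOperator(
        task_id="trigger_workflow_async",
        region="us-central1",
        project_id="my-project",
        template_id="my-workflow-template",
        deferrable=True,
    )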

The inline operator is an alternative. It creates a workflow, runs it, and deletes it afterwards:
:class:`~airflow.providers.google.cloud.operators.dataproc.DataprocInstantiateInlineWorkflowTemplateOperator`:

11 changes: 11 additions & 0 deletions tests/providers/google/cloud/hooks/test_dataproc.py
@@ -729,6 +729,17 @@ async def test_instantiate_workflow_template(self, mock_client):
metadata=(),
)

@pytest.mark.asyncio
@async_mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_operation"))
async def test_get_operation(self, mock_client):
mock_client.return_value = None
hook = DataprocAsyncHook(
gcp_conn_id="google_cloud_default", delegate_to=None, impersonation_chain=None
)
await hook.get_operation(region=GCP_LOCATION, operation_name="operation_name")
mock_client.assert_called_once_with(region=GCP_LOCATION, operation_name="operation_name")

@mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_template_client"))
def test_instantiate_workflow_template_missing_region(self, mock_client):
with pytest.raises(TypeError):
32 changes: 32 additions & 0 deletions tests/providers/google/cloud/operators/test_dataproc.py
@@ -59,6 +59,7 @@
DataprocClusterTrigger,
DataprocDeleteClusterTrigger,
DataprocSubmitTrigger,
DataprocWorkflowTrigger,
)
from airflow.providers.google.common.consts import GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
from airflow.serialization.serialized_objects import SerializedDAG
@@ -441,6 +442,7 @@ def test_deprecation_warning(self):
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_execute(self, mock_hook, to_dict_mock):
self.extra_links_manager_mock.attach_mock(mock_hook, "hook")
mock_hook.return_value.create_cluster.result.return_value = None
create_cluster_args = {
"region": GCP_REGION,
"project_id": GCP_PROJECT,
@@ -1363,6 +1365,36 @@ def test_execute(self, mock_hook):
metadata=METADATA,
)

@mock.patch(DATAPROC_PATH.format("DataprocHook"))
@mock.patch(DATAPROC_TRIGGERS_PATH.format("DataprocAsyncHook"))
def test_execute_call_defer_method(self, mock_trigger_hook, mock_hook):
operator = DataprocInstantiateWorkflowTemplateOperator(
task_id=TASK_ID,
template_id=TEMPLATE_ID,
region=GCP_REGION,
project_id=GCP_PROJECT,
version=2,
parameters={},
request_id=REQUEST_ID,
retry=RETRY,
timeout=TIMEOUT,
metadata=METADATA,
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
deferrable=True,
)

with pytest.raises(TaskDeferred) as exc:
operator.execute(mock.MagicMock())

mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)

mock_hook.return_value.instantiate_workflow_template.assert_called_once()

mock_hook.return_value.wait_for_operation.assert_not_called()
assert isinstance(exc.value.trigger, DataprocWorkflowTrigger)
assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME


@pytest.mark.need_serialized_dag
@mock.patch(DATAPROC_PATH.format("DataprocHook"))