Skip to content

Commit

Permalink
Handle create cluster error gracefully (#448)
Browse files Browse the repository at this point in the history
  • Loading branch information
pankajastro authored Jun 22, 2022
1 parent 1eaaa79 commit e19c656
Show file tree
Hide file tree
Showing 4 changed files with 403 additions and 20 deletions.
32 changes: 22 additions & 10 deletions astronomer/providers/google/cloud/operators/dataproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
DataprocUpdateClusterOperator,
)
from airflow.utils.context import Context
from google.api_core.exceptions import AlreadyExists

from astronomer.providers.google.cloud.triggers.dataproc import (
DataprocCreateClusterTrigger,
Expand Down Expand Up @@ -74,17 +75,25 @@ def __init__(
def execute(self, context: Context) -> None: # type: ignore[override]
"""Call create cluster API and defer to DataprocCreateClusterTrigger to check the status"""
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id)
hook.create_cluster(
region=self.region,
project_id=self.project_id,
cluster_name=self.cluster_name,
cluster_config=self.cluster_config,
labels=self.labels,
request_id=self.request_id,
retry=self.retry,
timeout=self.timeout,
metadata=self.metadata,
DataprocLink.persist(
context=context, task_instance=self, url=DATAPROC_CLUSTER_LINK, resource=self.cluster_name
)
try:
hook.create_cluster(
region=self.region,
project_id=self.project_id,
cluster_name=self.cluster_name,
cluster_config=self.cluster_config,
labels=self.labels,
request_id=self.request_id,
retry=self.retry,
timeout=self.timeout,
metadata=self.metadata,
)
except AlreadyExists:
if not self.use_if_exists:
raise
self.log.info("Cluster already exists.")

end_time: float = time.monotonic() + self.timeout

Expand All @@ -95,6 +104,9 @@ def execute(self, context: Context) -> None: # type: ignore[override]
cluster_name=self.cluster_name,
end_time=end_time,
metadata=self.metadata,
delete_on_error=self.delete_on_error,
cluster_config=self.cluster_config,
labels=self.labels,
gcp_conn_id=self.gcp_conn_id,
polling_interval=self.polling_interval,
),
Expand Down
104 changes: 95 additions & 9 deletions astronomer/providers/google/cloud/triggers/dataproc.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import asyncio
import time
from abc import ABC
from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple
from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Union

from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.dataproc import DataprocHook
from airflow.triggers.base import BaseTrigger, TriggerEvent
from google.api_core.exceptions import NotFound
from google.cloud.dataproc_v1 import Cluster
from google.cloud.dataproc_v1.types import JobStatus
from google.cloud.dataproc_v1.types import JobStatus, clusters

from astronomer.providers.google.cloud.hooks.dataproc import DataprocHookAsync

Expand All @@ -32,6 +34,9 @@ def __init__(
cluster_name: str,
end_time: float,
metadata: Sequence[Tuple[str, str]] = (),
delete_on_error: bool = True,
cluster_config: Optional[Union[Dict[str, Any], clusters.Cluster]] = None,
labels: Optional[Dict[str, str]] = None,
gcp_conn_id: str = "google_cloud_default",
polling_interval: float = 5.0,
**kwargs: Any,
Expand All @@ -42,6 +47,9 @@ def __init__(
self.cluster_name = cluster_name
self.end_time = end_time
self.metadata = metadata
self.delete_on_error = delete_on_error
self.cluster_config = cluster_config
self.labels = labels
self.gcp_conn_id = gcp_conn_id
self.polling_interval = polling_interval

Expand All @@ -55,22 +63,19 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]:
"cluster_name": self.cluster_name,
"end_time": self.end_time,
"metadata": self.metadata,
"delete_on_error": self.delete_on_error,
"cluster_config": self.cluster_config,
"labels": self.labels,
"gcp_conn_id": self.gcp_conn_id,
"polling_interval": self.polling_interval,
},
)

async def run(self) -> AsyncIterator["TriggerEvent"]: # type: ignore[override]
"""Check the status of cluster until reach the terminal state"""
hook = DataprocHookAsync(gcp_conn_id=self.gcp_conn_id)
while self.end_time > time.monotonic():
try:
cluster = await hook.get_cluster(
region=self.region, # type: ignore[arg-type]
cluster_name=self.cluster_name,
project_id=self.project_id, # type: ignore[arg-type]
metadata=self.metadata,
)
cluster = await self._get_cluster()
if cluster.status.state == cluster.status.State.RUNNING:
yield TriggerEvent(
{
Expand All @@ -79,6 +84,11 @@ async def run(self) -> AsyncIterator["TriggerEvent"]: # type: ignore[override]
"cluster_name": self.cluster_name,
}
)
elif cluster.status.state == cluster.status.State.DELETING:
await self._wait_for_deleting()
self._create_cluster()
await self._handle_error(cluster)

self.log.info(
"Cluster status is %s. Sleeping for %s seconds.",
cluster.status.state,
Expand All @@ -90,6 +100,82 @@ async def run(self) -> AsyncIterator["TriggerEvent"]: # type: ignore[override]

yield TriggerEvent({"status": "error", "message": "Timeout"})

async def _handle_error(self, cluster: clusters.Cluster) -> None:
    """Diagnose and fail the task when the cluster ended up in ERROR state.

    No-op unless ``cluster.status.state`` is ``ERROR``. Otherwise a diagnose
    request is issued (yielding a GCS URI with debug output), the cluster is
    deleted when ``delete_on_error`` is set, and an ``AirflowException`` is
    raised either way.

    :param cluster: the cluster object most recently fetched from the API
    :raises AirflowException: always, when the cluster state is ERROR
    """
    if cluster.status.state != cluster.status.State.ERROR:
        return
    self.log.info("Cluster is in ERROR state")
    gcs_uri = self._diagnose_cluster()
    self.log.info("Diagnostic information for cluster %s available at: %s", self.cluster_name, gcs_uri)
    if self.delete_on_error:
        self._delete_cluster()
        await self._wait_for_deleting()
        # NOTE: exception constructors do not lazily interpolate "%s" args the
        # way logging calls do, so the message must be formatted eagerly.
        raise AirflowException(
            "Cluster was created but was in ERROR state. \n"
            f" Diagnostic information for cluster {self.cluster_name} available at: {gcs_uri}"
        )
    raise AirflowException(
        "Cluster was created but is in ERROR state. \n "
        f"Diagnostic information for cluster {self.cluster_name} available at: {gcs_uri}"
    )

def _delete_cluster(self) -> None:
    """Issue a synchronous delete request for this trigger's cluster."""
    sync_hook = DataprocHook(gcp_conn_id=self.gcp_conn_id)
    sync_hook.delete_cluster(
        cluster_name=self.cluster_name,
        region=self.region,
        project_id=self.project_id,
        metadata=self.metadata,
    )

async def _wait_for_deleting(self) -> None:
    """Poll the cluster until it is gone (``NotFound``) or ``end_time`` passes.

    Returns silently once the cluster can no longer be fetched; any other
    API error propagates to the caller.
    """
    while self.end_time > time.monotonic():
        try:
            cluster = await self._get_cluster()
            # BUGFIX: the original tested `cluster.status.State.DELETING`,
            # i.e. the enum *member* itself, which is always truthy. Compare
            # the actual state instead.
            if cluster.status.state == cluster.status.State.DELETING:
                self.log.info(
                    "Cluster status is %s. Sleeping for %s seconds.",
                    cluster.status.state,
                    self.polling_interval,
                )
            # Sleep unconditionally so a cluster stuck in a non-DELETING
            # state does not busy-spin the event loop.
            await asyncio.sleep(self.polling_interval)
        except NotFound:
            # Cluster is fully deleted — done waiting.
            return

def _create_cluster(self) -> Any:
    """Re-submit a create request for the cluster and return the resulting operation."""
    sync_hook = DataprocHook(gcp_conn_id=self.gcp_conn_id)
    request_kwargs = {
        "region": self.region,
        "project_id": self.project_id,
        "cluster_name": self.cluster_name,
        "cluster_config": self.cluster_config,
        "labels": self.labels,
        "metadata": self.metadata,
    }
    return sync_hook.create_cluster(**request_kwargs)

async def _get_cluster(self) -> clusters.Cluster:
    """Fetch the current cluster object via the async Dataproc hook."""
    async_hook = DataprocHookAsync(gcp_conn_id=self.gcp_conn_id)
    cluster = await async_hook.get_cluster(
        cluster_name=self.cluster_name,
        region=self.region,  # type: ignore[arg-type]
        project_id=self.project_id,  # type: ignore[arg-type]
        metadata=self.metadata,
    )
    return cluster

def _diagnose_cluster(self) -> Any:
    """Run a diagnose request against the cluster and return the API's response."""
    sync_hook = DataprocHook(gcp_conn_id=self.gcp_conn_id)
    return sync_hook.diagnose_cluster(
        cluster_name=self.cluster_name,
        region=self.region,
        project_id=self.project_id,
        metadata=self.metadata,
    )


class DataprocDeleteClusterTrigger(BaseTrigger, ABC):
"""
Expand Down
61 changes: 60 additions & 1 deletion tests/google/cloud/operators/test_dataproc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from unittest import mock

import pendulum
import pytest
from airflow.exceptions import AirflowException, TaskDeferred
from airflow.models import DAG, DagRun, TaskInstance
from airflow.utils import timezone
from google.api_core.exceptions import AlreadyExists
from google.cloud import dataproc
from google.cloud.dataproc_v1 import Cluster

Expand Down Expand Up @@ -42,6 +46,24 @@ def context():
yield context


def create_context(task):
    """Build a minimal Airflow template context around *task* for operator tests.

    Supplies a throwaway DAG, a DagRun pinned to a fixed logical date, and a
    TaskInstance whose ``xcom_push`` is mocked out.
    """
    test_dag = DAG(dag_id="dag")
    amsterdam = pendulum.timezone("Europe/Amsterdam")
    logical_date = timezone.datetime(2016, 1, 1, 1, 0, 0, tzinfo=amsterdam)
    dag_run = DagRun(dag_id=test_dag.dag_id, execution_date=logical_date)
    ti = TaskInstance(task=task)
    ti.dag_run = dag_run
    ti.xcom_push = mock.Mock()
    context = {
        "dag": test_dag,
        "ts": logical_date.isoformat(),
        "task": task,
        "ti": ti,
        "task_instance": ti,
        "run_id": dag_run.run_id,
    }
    return context


@mock.patch("airflow.providers.google.cloud.operators.dataproc.DataprocHook.create_cluster")
def test_dataproc_operator_create_cluster_execute_async(mock_create_cluster):
"""
Expand All @@ -57,7 +79,44 @@ def test_dataproc_operator_create_cluster_execute_async(mock_create_cluster):
task_id="task-id", cluster_name="test_cluster", region=TEST_REGION, project_id=TEST_PROJECT_ID
)
with pytest.raises(TaskDeferred) as exc:
task.execute(context)
task.execute(create_context(task))
assert isinstance(
exc.value.trigger, DataprocCreateClusterTrigger
), "Trigger is not a DataprocCreateClusterTrigger"


@mock.patch("airflow.providers.google.cloud.operators.dataproc.DataprocHook.create_cluster")
def test_dataproc_operator_create_cluster_execute_async_cluster_exist_exception(mock_create_cluster):
    """
    Assert that the task re-raises ``AlreadyExists`` when the Dataproc cluster
    already exists and the ``use_if_exists`` param is False.
    """
    mock_create_cluster.side_effect = AlreadyExists("Cluster already exist")

    op_kwargs = {
        "task_id": "task-id",
        "cluster_name": "test_cluster",
        "region": TEST_REGION,
        "project_id": TEST_PROJECT_ID,
        "use_if_exists": False,
    }
    task = DataprocCreateClusterOperatorAsync(**op_kwargs)

    with pytest.raises(AlreadyExists):
        task.execute(create_context(task))


@mock.patch("airflow.providers.google.cloud.operators.dataproc.DataprocHook.create_cluster")
def test_dataproc_operator_create_cluster_execute_async_cluster_exist(mock_create_cluster):
    """
    Asserts that a task is deferred and a DataprocCreateClusterTrigger will be fired
    when DataprocCreateClusterOperatorAsync is executed and the Dataproc cluster
    already exists (``use_if_exists`` defaults to True).
    """
    # BUGFIX: use side_effect, not return_value — return_value would make the
    # mocked hook *return* the exception object instead of raising it, so the
    # already-exists branch of execute() was never actually exercised.
    mock_create_cluster.side_effect = AlreadyExists("Cluster already exist")

    task = DataprocCreateClusterOperatorAsync(
        task_id="task-id", cluster_name="test_cluster", region=TEST_REGION, project_id=TEST_PROJECT_ID
    )
    with pytest.raises(TaskDeferred) as exc:
        task.execute(create_context(task))
    assert isinstance(
        exc.value.trigger, DataprocCreateClusterTrigger
    ), "Trigger is not a DataprocCreateClusterTrigger"
Expand Down
Loading

0 comments on commit e19c656

Please sign in to comment.