-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement
DbtCloudJobRunOperatorAsync
and `DbtCloudJobRunSensorAsy…
…nc` (#623)
- Loading branch information
1 parent
08a9943
commit a07a98e
Showing
28 changed files
with
995 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Empty file.
Empty file.
61 changes: 61 additions & 0 deletions
61
astronomer/providers/dbt/cloud/example_dags/example_dbt_cloud.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
"""Example use of DBTCloudAsync related providers.""" | ||
|
||
import os | ||
from datetime import timedelta | ||
|
||
from airflow import DAG | ||
from airflow.operators.empty import EmptyOperator | ||
from airflow.providers.dbt.cloud.operators.dbt import DbtCloudRunJobOperator | ||
from airflow.utils.timezone import datetime | ||
|
||
from astronomer.providers.dbt.cloud.operators.dbt import DbtCloudRunJobOperatorAsync | ||
from astronomer.providers.dbt.cloud.sensors.dbt import DbtCloudJobRunSensorAsync | ||
|
||
DBT_CLOUD_CONN_ID = os.getenv("ASTRO_DBT_CLOUD_CONN", "dbt_cloud_default") | ||
DBT_CLOUD_ACCOUNT_ID = os.getenv("ASTRO_DBT_CLOUD_ACCOUNT_ID", 12345) | ||
DBT_CLOUD_JOB_ID = int(os.getenv("ASTRO_DBT_CLOUD_JOB_ID", 12345)) | ||
EXECUTION_TIMEOUT = int(os.getenv("EXECUTION_TIMEOUT", 6)) | ||
|
||
|
||
default_args = { | ||
"execution_timeout": timedelta(hours=EXECUTION_TIMEOUT), | ||
"dbt_cloud_conn_id": DBT_CLOUD_CONN_ID, | ||
"account_id": DBT_CLOUD_ACCOUNT_ID, | ||
"retries": int(os.getenv("DEFAULT_TASK_RETRIES", 2)), | ||
"retry_delay": timedelta(seconds=int(os.getenv("DEFAULT_RETRY_DELAY_SECONDS", 60))), | ||
} | ||
|
||
with DAG( | ||
dag_id="example_dbt_cloud", | ||
start_date=datetime(2022, 1, 1), | ||
schedule_interval=None, | ||
default_args=default_args, | ||
tags=["example", "async", "dbt-cloud"], | ||
catchup=False, | ||
) as dag: | ||
start = EmptyOperator(task_id="start") | ||
end = EmptyOperator(task_id="end") | ||
# [START howto_operator_dbt_cloud_run_job_async] | ||
trigger_dbt_job_run_async = DbtCloudRunJobOperatorAsync( | ||
task_id="trigger_dbt_job_run_async", | ||
job_id=DBT_CLOUD_JOB_ID, | ||
check_interval=10, | ||
timeout=300, | ||
) | ||
# [END howto_operator_dbt_cloud_run_job_async] | ||
|
||
trigger_job_run2 = DbtCloudRunJobOperator( | ||
task_id="trigger_job_run2", | ||
job_id=DBT_CLOUD_JOB_ID, | ||
wait_for_termination=False, | ||
additional_run_config={"threads_override": 8}, | ||
) | ||
|
||
# [START howto_operator_dbt_cloud_run_job_sensor_async] | ||
job_run_sensor_async = DbtCloudJobRunSensorAsync( | ||
task_id="job_run_sensor_async", run_id=trigger_job_run2.output, timeout=20 | ||
) | ||
# [END howto_operator_dbt_cloud_run_job_sensor_async] | ||
|
||
start >> trigger_dbt_job_run_async >> end | ||
start >> trigger_job_run2 >> job_run_sensor_async >> end |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
from functools import wraps | ||
from inspect import signature | ||
from typing import Any, Dict, List, Optional, Tuple, TypeVar, cast | ||
|
||
import aiohttp | ||
from aiohttp import ClientResponseError | ||
from airflow import AirflowException | ||
from airflow.hooks.base import BaseHook | ||
from airflow.models import Connection | ||
from asgiref.sync import sync_to_async | ||
|
||
from astronomer.providers.package import get_provider_info | ||
|
||
T = TypeVar("T", bound=Any) | ||
|
||
|
||
def provide_account_id(func: T) -> T: | ||
""" | ||
Decorator which provides a fallback value for ``account_id``. If the ``account_id`` is None or not passed | ||
to the decorated function, the value will be taken from the configured dbt Cloud Airflow Connection. | ||
""" | ||
function_signature = signature(func) | ||
|
||
@wraps(func) | ||
async def wrapper(*args: Any, **kwargs: Any) -> Any: | ||
bound_args = function_signature.bind(*args, **kwargs) | ||
|
||
if bound_args.arguments.get("account_id") is None: | ||
self = args[0] | ||
if self.dbt_cloud_conn_id: | ||
connection = await sync_to_async(self.get_connection)(self.dbt_cloud_conn_id) | ||
default_account_id = connection.login | ||
if not default_account_id: | ||
raise AirflowException("Could not determine the dbt Cloud account.") | ||
bound_args.arguments["account_id"] = int(default_account_id) | ||
|
||
return await func(*bound_args.args, **bound_args.kwargs) | ||
|
||
return cast(T, wrapper) | ||
|
||
|
||
class DbtCloudHookAsync(BaseHook): | ||
""" | ||
Interact with dbt Cloud using the V2 API. | ||
:param dbt_cloud_conn_id: The ID of the :ref:`dbt Cloud connection <howto/connection:dbt-cloud>`. | ||
""" | ||
|
||
conn_name_attr = "dbt_cloud_conn_id" | ||
default_conn_name = "dbt_cloud_default" | ||
conn_type = "dbt_cloud" | ||
hook_name = "dbt Cloud" | ||
|
||
def __init__(self, dbt_cloud_conn_id: str): | ||
self.dbt_cloud_conn_id = dbt_cloud_conn_id | ||
|
||
async def get_headers_tenants_from_connection(self) -> Tuple[Dict[str, Any], str]: | ||
"""Get Headers, tenants from the connection details""" | ||
headers: Dict[str, Any] = {} | ||
connection: Connection = await sync_to_async(self.get_connection)(self.dbt_cloud_conn_id) | ||
tenant: str = connection.schema if connection.schema else "cloud" | ||
provider_info = get_provider_info() | ||
package_name = provider_info["package-name"] | ||
version = provider_info["versions"] | ||
headers["User-Agent"] = f"{package_name}-v{version}" | ||
headers["Content-Type"] = "application/json" | ||
headers["Authorization"] = f"Token {connection.password}" | ||
return headers, tenant | ||
|
||
@staticmethod | ||
def get_request_url_params( | ||
tenant: str, endpoint: str, include_related: Optional[List[str]] = None | ||
) -> Tuple[str, Dict[str, Any]]: | ||
""" | ||
Form URL from base url and endpoint url | ||
:param tenant: The tenant name which is need to be replaced in base url. | ||
:param endpoint: Endpoint url to be requested. | ||
:param include_related: Optional. List of related fields to pull with the run. | ||
Valid values are "trigger", "job", "repository", and "environment". | ||
""" | ||
data: Dict[str, Any] = {} | ||
base_url = f"https://{tenant}.getdbt.com/api/v2/accounts/" | ||
if include_related: | ||
data = {"include_related": include_related} | ||
if base_url and not base_url.endswith("/") and endpoint and not endpoint.startswith("/"): | ||
url = base_url + "/" + endpoint | ||
else: | ||
url = (base_url or "") + (endpoint or "") | ||
return url, data | ||
|
||
@provide_account_id | ||
async def get_job_details( | ||
self, run_id: int, account_id: Optional[int] = None, include_related: Optional[List[str]] = None | ||
) -> Any: | ||
""" | ||
Uses Http async call to retrieve metadata for a specific run of a dbt Cloud job. | ||
:param run_id: The ID of a dbt Cloud job run. | ||
:param account_id: Optional. The ID of a dbt Cloud account. | ||
:param include_related: Optional. List of related fields to pull with the run. | ||
Valid values are "trigger", "job", "repository", and "environment". | ||
""" | ||
endpoint = f"{account_id}/runs/{run_id}/" | ||
headers, tenant = await self.get_headers_tenants_from_connection() | ||
url, params = self.get_request_url_params(tenant, endpoint, include_related) | ||
async with aiohttp.ClientSession(headers=headers) as session: | ||
async with session.get(url, params=params) as response: | ||
try: | ||
response.raise_for_status() | ||
return await response.json() | ||
except ClientResponseError as e: | ||
raise AirflowException(str(e.status) + ":" + e.message) | ||
|
||
async def get_job_status( | ||
self, run_id: int, account_id: Optional[int] = None, include_related: Optional[List[str]] = None | ||
) -> int: | ||
""" | ||
Retrieves the status for a specific run of a dbt Cloud job. | ||
:param run_id: The ID of a dbt Cloud job run. | ||
:param account_id: Optional. The ID of a dbt Cloud account. | ||
:param include_related: Optional. List of related fields to pull with the run. | ||
Valid values are "trigger", "job", "repository", and "environment". | ||
""" | ||
try: | ||
self.log.info("Getting the status of job run %s.", str(run_id)) | ||
response = await self.get_job_details( | ||
run_id, account_id=account_id, include_related=include_related | ||
) | ||
job_run_status: int = response["data"]["status"] | ||
return job_run_status | ||
except Exception as e: | ||
raise e |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
import time | ||
from typing import TYPE_CHECKING, Any, Dict | ||
|
||
from airflow import AirflowException | ||
from airflow.providers.dbt.cloud.hooks.dbt import DbtCloudHook | ||
from airflow.providers.dbt.cloud.operators.dbt import DbtCloudRunJobOperator | ||
|
||
from astronomer.providers.dbt.cloud.triggers.dbt import DbtCloudRunJobTrigger | ||
|
||
if TYPE_CHECKING: # pragma: no cover | ||
from airflow.utils.context import Context | ||
|
||
|
||
class DbtCloudRunJobOperatorAsync(DbtCloudRunJobOperator): | ||
""" | ||
Executes a dbt Cloud job asynchronously. Trigger the dbt cloud job via worker to dbt and with run id in response | ||
poll for the status in trigger. | ||
.. seealso:: | ||
For more information on sync Operator DbtCloudRunJobOperator, take a look at the guide: | ||
:ref:`howto/operator:DbtCloudRunJobOperator` | ||
:param dbt_cloud_conn_id: The connection ID for connecting to dbt Cloud. | ||
:param job_id: The ID of a dbt Cloud job. | ||
:param account_id: Optional. The ID of a dbt Cloud account. | ||
:param trigger_reason: Optional Description of the reason to trigger the job. Dbt requires the trigger reason while | ||
making an API. if it is not provided uses the default reasons. | ||
:param steps_override: Optional. List of dbt commands to execute when triggering the job instead of those | ||
configured in dbt Cloud. | ||
:param schema_override: Optional. Override the destination schema in the configured target for this job. | ||
:param timeout: Time in seconds to wait for a job run to reach a terminal status. Defaults to 7 days. | ||
:param check_interval: Time in seconds to check on a job run's status. Defaults to 60 seconds. | ||
:param additional_run_config: Optional. Any additional parameters that should be included in the API | ||
request when triggering the job. | ||
:return: The ID of the triggered dbt Cloud job run. | ||
""" | ||
|
||
def execute(self, context: "Context") -> None: # type: ignore[override] | ||
"""Submits a job which generates a run_id and gets deferred""" | ||
if self.trigger_reason is None: | ||
self.trigger_reason = ( | ||
f"Triggered via Apache Airflow by task {self.task_id!r} in the {self.dag.dag_id} DAG." | ||
) | ||
hook = DbtCloudHook(dbt_cloud_conn_id=self.dbt_cloud_conn_id) | ||
trigger_job_response = hook.trigger_job_run( | ||
account_id=self.account_id, | ||
job_id=self.job_id, | ||
cause=self.trigger_reason, | ||
steps_override=self.steps_override, | ||
schema_override=self.schema_override, | ||
additional_run_config=self.additional_run_config, | ||
) | ||
run_id = trigger_job_response.json()["data"]["id"] | ||
job_run_url = trigger_job_response.json()["data"]["href"] | ||
|
||
context["ti"].xcom_push(key="job_run_url", value=job_run_url) | ||
end_time = time.time() + self.timeout | ||
self.defer( | ||
timeout=self.execution_timeout, | ||
trigger=DbtCloudRunJobTrigger( | ||
conn_id=self.dbt_cloud_conn_id, | ||
run_id=run_id, | ||
end_time=end_time, | ||
account_id=self.account_id, | ||
poll_interval=self.check_interval, | ||
), | ||
method_name="execute_complete", | ||
) | ||
|
||
def execute_complete(self, context: "Context", event: Dict[str, Any]) -> int: | ||
""" | ||
Callback for when the trigger fires - returns immediately. | ||
Relies on trigger to throw an exception, otherwise it assumes execution was | ||
successful. | ||
""" | ||
if event["status"] == "error": | ||
raise AirflowException(event["message"]) | ||
self.log.info(event["message"]) | ||
return int(event["run_id"]) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import time | ||
from typing import TYPE_CHECKING, Any, Dict | ||
|
||
from airflow import AirflowException | ||
from airflow.providers.dbt.cloud.sensors.dbt import DbtCloudJobRunSensor | ||
|
||
from astronomer.providers.dbt.cloud.triggers.dbt import DbtCloudRunJobTrigger | ||
|
||
if TYPE_CHECKING: # pragma: no cover | ||
from airflow.utils.context import Context | ||
|
||
|
||
class DbtCloudJobRunSensorAsync(DbtCloudJobRunSensor): | ||
""" | ||
Checks the status of a dbt Cloud job run. | ||
.. seealso:: | ||
For more information on sync Sensor DbtCloudJobRunSensor, take a look at the guide:: | ||
:ref:`howto/operator:DbtCloudJobRunSensor` | ||
:param dbt_cloud_conn_id: The connection identifier for connecting to dbt Cloud. | ||
:param run_id: The job run identifier. | ||
:param account_id: The dbt Cloud account identifier. | ||
:param timeout: Time in seconds to wait for a job run to reach a terminal status. Defaults to 7 days. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
*, | ||
poll_interval: float = 5, | ||
timeout: float = 60 * 60 * 24 * 7, | ||
**kwargs: Any, | ||
): | ||
self.poll_interval = poll_interval | ||
self.timeout = timeout | ||
super().__init__(**kwargs) | ||
|
||
def execute(self, context: "Context") -> None: | ||
"""Defers trigger class to poll for state of the job run until it reaches a failure state or success state""" | ||
end_time = time.time() + self.timeout | ||
self.defer( | ||
timeout=self.execution_timeout, | ||
trigger=DbtCloudRunJobTrigger( | ||
run_id=self.run_id, | ||
conn_id=self.dbt_cloud_conn_id, | ||
account_id=self.account_id, | ||
poll_interval=self.poll_interval, | ||
end_time=end_time, | ||
), | ||
method_name="execute_complete", | ||
) | ||
|
||
def execute_complete(self, context: "Context", event: Dict[str, Any]) -> int: | ||
""" | ||
Callback for when the trigger fires - returns immediately. | ||
Relies on trigger to throw an exception, otherwise it assumes execution was | ||
successful. | ||
""" | ||
if event["status"] in ["error", "cancelled"]: | ||
raise AirflowException(event["message"]) | ||
self.log.info(event["message"]) | ||
return int(event["run_id"]) |
Empty file.
Oops, something went wrong.