Skip to content

Commit

Permalink
Implement DbtCloudJobRunOperatorAsync and `DbtCloudJobRunSensorAsy…
Browse files Browse the repository at this point in the history
…nc` (#623)
  • Loading branch information
bharanidharan14 authored Sep 12, 2022
1 parent 08a9943 commit a07a98e
Show file tree
Hide file tree
Showing 28 changed files with 995 additions and 0 deletions.
4 changes: 4 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ Extras
- ``pip install 'astronomer-providers[databricks]'``
- Databricks

* - ``dbt.cloud``
- ``pip install 'astronomer-providers[dbt.cloud]'``
- Dbt Cloud

* - ``google``
- ``pip install 'astronomer-providers[google]'``
- Google
Expand Down
Empty file.
Empty file.
Empty file.
61 changes: 61 additions & 0 deletions astronomer/providers/dbt/cloud/example_dags/example_dbt_cloud.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""Example use of DBTCloudAsync related providers."""

import os
from datetime import timedelta

from airflow import DAG
from airflow.operators.empty import EmptyOperator
from airflow.providers.dbt.cloud.operators.dbt import DbtCloudRunJobOperator
from airflow.utils.timezone import datetime

from astronomer.providers.dbt.cloud.operators.dbt import DbtCloudRunJobOperatorAsync
from astronomer.providers.dbt.cloud.sensors.dbt import DbtCloudJobRunSensorAsync

# Every setting below can be overridden through the environment; the literals are
# placeholder defaults for the example.
DBT_CLOUD_CONN_ID = os.getenv("ASTRO_DBT_CLOUD_CONN", "dbt_cloud_default")
# os.getenv returns the raw string when the variable is set, so coerce to int to keep
# the account id the same type whether it comes from the environment or the default.
DBT_CLOUD_ACCOUNT_ID = int(os.getenv("ASTRO_DBT_CLOUD_ACCOUNT_ID", 12345))
DBT_CLOUD_JOB_ID = int(os.getenv("ASTRO_DBT_CLOUD_JOB_ID", 12345))
# Expressed in hours (see ``execution_timeout`` below).
EXECUTION_TIMEOUT = int(os.getenv("EXECUTION_TIMEOUT", 6))


# Arguments applied to every task in the DAG unless a task overrides them.
default_args = {
    "execution_timeout": timedelta(hours=EXECUTION_TIMEOUT),
    "dbt_cloud_conn_id": DBT_CLOUD_CONN_ID,
    "account_id": DBT_CLOUD_ACCOUNT_ID,
    "retries": int(os.getenv("DEFAULT_TASK_RETRIES", 2)),
    "retry_delay": timedelta(seconds=int(os.getenv("DEFAULT_RETRY_DELAY_SECONDS", 60))),
}

with DAG(
    dag_id="example_dbt_cloud",
    default_args=default_args,
    schedule_interval=None,
    start_date=datetime(2022, 1, 1),
    catchup=False,
    tags=["example", "async", "dbt-cloud"],
) as dag:
    # Shared entry and exit points for the two independent branches below.
    start = EmptyOperator(task_id="start")
    end = EmptyOperator(task_id="end")

    # Branch 1: the async operator triggers the job and defers until the run finishes.
    # [START howto_operator_dbt_cloud_run_job_async]
    trigger_dbt_job_run_async = DbtCloudRunJobOperatorAsync(
        task_id="trigger_dbt_job_run_async",
        job_id=DBT_CLOUD_JOB_ID,
        check_interval=10,
        timeout=300,
    )
    # [END howto_operator_dbt_cloud_run_job_async]

    # Branch 2: the sync operator only triggers the job (no waiting); the async sensor
    # then watches the run whose id the operator produced.
    trigger_job_run2 = DbtCloudRunJobOperator(
        task_id="trigger_job_run2",
        job_id=DBT_CLOUD_JOB_ID,
        wait_for_termination=False,
        additional_run_config={"threads_override": 8},
    )

    # [START howto_operator_dbt_cloud_run_job_sensor_async]
    job_run_sensor_async = DbtCloudJobRunSensorAsync(
        task_id="job_run_sensor_async",
        run_id=trigger_job_run2.output,
        timeout=20,
    )
    # [END howto_operator_dbt_cloud_run_job_sensor_async]

    # Same edges as two linear chains: start -> async operator -> end, and
    # start -> sync trigger -> async sensor -> end.
    start >> [trigger_dbt_job_run_async, trigger_job_run2]
    trigger_job_run2 >> job_run_sensor_async
    [trigger_dbt_job_run_async, job_run_sensor_async] >> end
Empty file.
134 changes: 134 additions & 0 deletions astronomer/providers/dbt/cloud/hooks/dbt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
from functools import wraps
from inspect import signature
from typing import Any, Dict, List, Optional, Tuple, TypeVar, cast

import aiohttp
from aiohttp import ClientResponseError
from airflow import AirflowException
from airflow.hooks.base import BaseHook
from airflow.models import Connection
from asgiref.sync import sync_to_async

from astronomer.providers.package import get_provider_info

T = TypeVar("T", bound=Any)


def provide_account_id(func: T) -> T:
    """
    Decorate an async hook method so that ``account_id`` falls back to the connection login.

    When the decorated coroutine is called with ``account_id`` missing or ``None`` and the hook
    (the first positional argument) has a truthy ``dbt_cloud_conn_id``, the Airflow connection is
    fetched and its ``login`` is converted to ``int`` and used as the account id.

    :param func: The coroutine function to wrap; its first positional argument is the hook.
    :raises AirflowException: If the fallback is attempted but the connection has no login.
    """
    sig = signature(func)

    @wraps(func)
    async def wrapper(*args: Any, **kwargs: Any) -> Any:
        call_args = sig.bind(*args, **kwargs)
        account_missing = call_args.arguments.get("account_id") is None

        # ``args[0]`` is only inspected when a fallback is actually required.
        if account_missing and args[0].dbt_cloud_conn_id:
            hook = args[0]
            # get_connection is synchronous; run it off the event loop thread.
            conn = await sync_to_async(hook.get_connection)(hook.dbt_cloud_conn_id)
            login = conn.login
            if not login:
                raise AirflowException("Could not determine the dbt Cloud account.")
            call_args.arguments["account_id"] = int(login)

        return await func(*call_args.args, **call_args.kwargs)

    return cast(T, wrapper)


class DbtCloudHookAsync(BaseHook):
    """
    Interact with dbt Cloud using the V2 API.

    :param dbt_cloud_conn_id: The ID of the :ref:`dbt Cloud connection <howto/connection:dbt-cloud>`.
    """

    conn_name_attr = "dbt_cloud_conn_id"
    default_conn_name = "dbt_cloud_default"
    conn_type = "dbt_cloud"
    hook_name = "dbt Cloud"

    def __init__(self, dbt_cloud_conn_id: str):
        super().__init__()
        self.dbt_cloud_conn_id = dbt_cloud_conn_id

    async def get_headers_tenants_from_connection(self) -> Tuple[Dict[str, Any], str]:
        """
        Build the request headers and resolve the tenant from the Airflow connection.

        The tenant is the connection ``schema`` (``"cloud"`` when empty) and the connection
        ``password`` is sent as the API token.

        :return: A ``(headers, tenant)`` tuple.
        """
        # get_connection is synchronous; run it off the event loop thread.
        connection: Connection = await sync_to_async(self.get_connection)(self.dbt_cloud_conn_id)
        tenant: str = connection.schema or "cloud"
        provider_info = get_provider_info()
        package_name = provider_info["package-name"]
        version = provider_info["versions"]
        headers: Dict[str, Any] = {
            "User-Agent": f"{package_name}-v{version}",
            "Content-Type": "application/json",
            "Authorization": f"Token {connection.password}",
        }
        return headers, tenant

    @staticmethod
    def get_request_url_params(
        tenant: str, endpoint: str, include_related: Optional[List[str]] = None
    ) -> Tuple[str, Dict[str, Any]]:
        """
        Form URL from base url and endpoint url

        :param tenant: The tenant name which is need to be replaced in base url.
        :param endpoint: Endpoint url to be requested.
        :param include_related: Optional. List of related fields to pull with the run.
            Valid values are "trigger", "job", "repository", and "environment".
        :return: A ``(url, params)`` tuple; ``params`` is empty when ``include_related`` is not given.
        """
        base_url = f"https://{tenant}.getdbt.com/api/v2/accounts/"
        # The base URL always ends with "/", so strip leading slashes from the endpoint
        # to join the two with exactly one separator.
        url = base_url + (endpoint or "").lstrip("/")
        data: Dict[str, Any] = {"include_related": include_related} if include_related else {}
        return url, data

    @provide_account_id
    async def get_job_details(
        self, run_id: int, account_id: Optional[int] = None, include_related: Optional[List[str]] = None
    ) -> Any:
        """
        Uses Http async call to retrieve metadata for a specific run of a dbt Cloud job.

        :param run_id: The ID of a dbt Cloud job run.
        :param account_id: Optional. The ID of a dbt Cloud account.
        :param include_related: Optional. List of related fields to pull with the run.
            Valid values are "trigger", "job", "repository", and "environment".
        :return: The decoded JSON body of the response.
        :raises AirflowException: If the request fails with a ``ClientResponseError``.
        """
        endpoint = f"{account_id}/runs/{run_id}/"
        headers, tenant = await self.get_headers_tenants_from_connection()
        url, params = self.get_request_url_params(tenant, endpoint, include_related)
        async with aiohttp.ClientSession(headers=headers) as session:
            async with session.get(url, params=params) as response:
                try:
                    response.raise_for_status()
                    return await response.json()
                except ClientResponseError as e:
                    # Chain the original error so the HTTP traceback stays visible.
                    raise AirflowException(f"{e.status}:{e.message}") from e

    async def get_job_status(
        self, run_id: int, account_id: Optional[int] = None, include_related: Optional[List[str]] = None
    ) -> int:
        """
        Retrieves the status for a specific run of a dbt Cloud job.

        :param run_id: The ID of a dbt Cloud job run.
        :param account_id: Optional. The ID of a dbt Cloud account.
        :param include_related: Optional. List of related fields to pull with the run.
            Valid values are "trigger", "job", "repository", and "environment".
        :return: The value of ``data.status`` from the job run details response.
        """
        self.log.info("Getting the status of job run %s.", run_id)
        response = await self.get_job_details(run_id, account_id=account_id, include_related=include_related)
        job_run_status: int = response["data"]["status"]
        return job_run_status
Empty file.
79 changes: 79 additions & 0 deletions astronomer/providers/dbt/cloud/operators/dbt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import time
from typing import TYPE_CHECKING, Any, Dict

from airflow import AirflowException
from airflow.providers.dbt.cloud.hooks.dbt import DbtCloudHook
from airflow.providers.dbt.cloud.operators.dbt import DbtCloudRunJobOperator

from astronomer.providers.dbt.cloud.triggers.dbt import DbtCloudRunJobTrigger

if TYPE_CHECKING: # pragma: no cover
from airflow.utils.context import Context


class DbtCloudRunJobOperatorAsync(DbtCloudRunJobOperator):
    """
    Executes a dbt Cloud job asynchronously. Trigger the dbt cloud job via worker to dbt and with run id in response
    poll for the status in trigger.

    .. seealso::
        For more information on sync Operator DbtCloudRunJobOperator, take a look at the guide:
        :ref:`howto/operator:DbtCloudRunJobOperator`

    :param dbt_cloud_conn_id: The connection ID for connecting to dbt Cloud.
    :param job_id: The ID of a dbt Cloud job.
    :param account_id: Optional. The ID of a dbt Cloud account.
    :param trigger_reason: Optional. Description of the reason to trigger the job. dbt requires a trigger reason
        when making the API request; if it is not provided, a default reason is used.
    :param steps_override: Optional. List of dbt commands to execute when triggering the job instead of those
        configured in dbt Cloud.
    :param schema_override: Optional. Override the destination schema in the configured target for this job.
    :param timeout: Time in seconds to wait for a job run to reach a terminal status. Defaults to 7 days.
    :param check_interval: Time in seconds to check on a job run's status. Defaults to 60 seconds.
    :param additional_run_config: Optional. Any additional parameters that should be included in the API
        request when triggering the job.
    :return: The ID of the triggered dbt Cloud job run.
    """

    def execute(self, context: "Context") -> None:  # type: ignore[override]
        """
        Submits a job which generates a run_id and gets deferred

        The job run URL from the trigger response is pushed to XCom under ``job_run_url`` before
        deferring to ``DbtCloudRunJobTrigger``.
        """
        if self.trigger_reason is None:
            self.trigger_reason = (
                f"Triggered via Apache Airflow by task {self.task_id!r} in the {self.dag.dag_id} DAG."
            )
        hook = DbtCloudHook(dbt_cloud_conn_id=self.dbt_cloud_conn_id)
        trigger_job_response = hook.trigger_job_run(
            account_id=self.account_id,
            job_id=self.job_id,
            cause=self.trigger_reason,
            steps_override=self.steps_override,
            schema_override=self.schema_override,
            additional_run_config=self.additional_run_config,
        )
        # Decode the response body once and read both fields from it.
        run_data = trigger_job_response.json()["data"]
        run_id = run_data["id"]
        job_run_url = run_data["href"]

        context["ti"].xcom_push(key="job_run_url", value=job_run_url)
        # Absolute deadline handed to the trigger along with the polling interval.
        end_time = time.time() + self.timeout
        self.defer(
            timeout=self.execution_timeout,
            trigger=DbtCloudRunJobTrigger(
                conn_id=self.dbt_cloud_conn_id,
                run_id=run_id,
                end_time=end_time,
                account_id=self.account_id,
                poll_interval=self.check_interval,
            ),
            method_name="execute_complete",
        )

    def execute_complete(self, context: "Context", event: Dict[str, Any]) -> int:
        """
        Callback for when the trigger fires - returns immediately.

        :param context: The task context.
        :param event: Trigger payload; ``status``, ``message`` and ``run_id`` are read from it.
        :return: The ID of the dbt Cloud job run.
        :raises AirflowException: If the event status is ``"error"`` or ``"cancelled"``.
        """
        # A cancelled run did not succeed, so fail the task for it as well as for errors.
        if event["status"] in ("error", "cancelled"):
            raise AirflowException(event["message"])
        self.log.info(event["message"])
        return int(event["run_id"])
Empty file.
62 changes: 62 additions & 0 deletions astronomer/providers/dbt/cloud/sensors/dbt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import time
from typing import TYPE_CHECKING, Any, Dict

from airflow import AirflowException
from airflow.providers.dbt.cloud.sensors.dbt import DbtCloudJobRunSensor

from astronomer.providers.dbt.cloud.triggers.dbt import DbtCloudRunJobTrigger

if TYPE_CHECKING: # pragma: no cover
from airflow.utils.context import Context


class DbtCloudJobRunSensorAsync(DbtCloudJobRunSensor):
    """
    Checks the status of a dbt Cloud job run.

    .. seealso::
        For more information on sync Sensor DbtCloudJobRunSensor, take a look at the guide:
        :ref:`howto/operator:DbtCloudJobRunSensor`

    :param dbt_cloud_conn_id: The connection identifier for connecting to dbt Cloud.
    :param run_id: The job run identifier.
    :param account_id: The dbt Cloud account identifier.
    :param poll_interval: Time in seconds between status checks made by the trigger. Defaults to 5 seconds.
    :param timeout: Time in seconds to wait for a job run to reach a terminal status. Defaults to 7 days.
    """

    def __init__(
        self,
        *,
        poll_interval: float = 5,
        timeout: float = 60 * 60 * 24 * 7,
        **kwargs: Any,
    ):
        # Forward ``timeout`` to the base sensor: it assigns ``self.timeout`` itself, so a
        # value set here beforehand would be replaced by the base default.
        super().__init__(timeout=timeout, **kwargs)
        self.poll_interval = poll_interval

    def execute(self, context: "Context") -> None:
        """Defers trigger class to poll for state of the job run until it reaches a failure state or success state"""
        # Absolute deadline handed to the trigger along with the polling interval.
        end_time = time.time() + self.timeout
        self.defer(
            timeout=self.execution_timeout,
            trigger=DbtCloudRunJobTrigger(
                run_id=self.run_id,
                conn_id=self.dbt_cloud_conn_id,
                account_id=self.account_id,
                poll_interval=self.poll_interval,
                end_time=end_time,
            ),
            method_name="execute_complete",
        )

    def execute_complete(self, context: "Context", event: Dict[str, Any]) -> int:
        """
        Callback for when the trigger fires - returns immediately.

        :param context: The task context.
        :param event: Trigger payload; ``status``, ``message`` and ``run_id`` are read from it.
        :return: The ID of the dbt Cloud job run.
        :raises AirflowException: If the event status is ``"error"`` or ``"cancelled"``.
        """
        if event["status"] in ["error", "cancelled"]:
            raise AirflowException(event["message"])
        self.log.info(event["message"])
        return int(event["run_id"])
Empty file.
Loading

0 comments on commit a07a98e

Please sign in to comment.