Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

migrate to dbt v3 api for project endpoints #39214

Merged
merged 2 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 14 additions & 11 deletions airflow/providers/dbt/cloud/hooks/dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any:

class DbtCloudHook(HttpHook):
"""
Interact with dbt Cloud using the V2 API.
Interact with dbt Cloud using the V2 (V3 if supported) API.

:param dbt_cloud_conn_id: The ID of the :ref:`dbt Cloud connection <howto/connection:dbt-cloud>`.
"""
Expand Down Expand Up @@ -194,7 +194,7 @@ def _get_tenant_domain(conn: Connection) -> str:

@staticmethod
def get_request_url_params(
tenant: str, endpoint: str, include_related: list[str] | None = None
tenant: str, endpoint: str, include_related: list[str] | None = None, *, api_version: str = "v2"
) -> tuple[str, dict[str, Any]]:
"""
Form URL from base url and endpoint url.
Expand All @@ -207,7 +207,7 @@ def get_request_url_params(
data: dict[str, Any] = {}
if include_related:
data = {"include_related": include_related}
url = f"https://{tenant}/api/v2/accounts/{endpoint or ''}"
url = f"https://{tenant}/api/{api_version}/accounts/{endpoint or ''}"
return url, data

async def get_headers_tenants_from_connection(self) -> tuple[dict[str, Any], str]:
Expand Down Expand Up @@ -270,7 +270,7 @@ def connection(self) -> Connection:

def get_conn(self, *args, **kwargs) -> Session:
tenant = self._get_tenant_domain(self.connection)
self.base_url = f"https://{tenant}/api/v2/accounts/"
self.base_url = f"https://{tenant}/"

session = Session()
session.auth = self.auth_type(self.connection.password)
Expand Down Expand Up @@ -298,23 +298,26 @@ def _paginate(self, endpoint: str, payload: dict[str, Any] | None = None) -> lis

def _run_and_get_response(
self,
*,
method: str = "GET",
endpoint: str | None = None,
payload: str | dict[str, Any] | None = None,
paginate: bool = False,
api_version: str = "v2",
) -> Any:
self.method = method
full_endpoint = f"api/{api_version}/accounts/{endpoint}" if endpoint else None

if paginate:
if isinstance(payload, str):
raise ValueError("Payload cannot be a string to paginate a response.")

if endpoint:
return self._paginate(endpoint=endpoint, payload=payload)
else:
raise ValueError("An endpoint is needed to paginate a response.")
if full_endpoint:
return self._paginate(endpoint=full_endpoint, payload=payload)

return self.run(endpoint=endpoint, data=payload)
raise ValueError("An endpoint is needed to paginate a response.")

return self.run(endpoint=full_endpoint, data=payload)

def list_accounts(self) -> list[Response]:
"""
Expand Down Expand Up @@ -342,7 +345,7 @@ def list_projects(self, account_id: int | None = None) -> list[Response]:
:param account_id: Optional. The ID of a dbt Cloud account.
:return: List of request responses.
"""
return self._run_and_get_response(endpoint=f"{account_id}/projects/", paginate=True)
return self._run_and_get_response(endpoint=f"{account_id}/projects/", paginate=True, api_version="v3")

@fallback_to_default_account
def get_project(self, project_id: int, account_id: int | None = None) -> Response:
Expand All @@ -353,7 +356,7 @@ def get_project(self, project_id: int, account_id: int | None = None) -> Respons
:param account_id: Optional. The ID of a dbt Cloud account.
:return: The request response.
"""
return self._run_and_get_response(endpoint=f"{account_id}/projects/{project_id}/")
return self._run_and_get_response(endpoint=f"{account_id}/projects/{project_id}/", api_version="v3")

@fallback_to_default_account
def list_jobs(
Expand Down
51 changes: 30 additions & 21 deletions tests/providers/dbt/cloud/hooks/test_dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@
JOB_ID = 4444
RUN_ID = 5555

BASE_URL = "https://cloud.getdbt.com/api/v2/accounts/"
SINGLE_TENANT_URL = "https://single.tenant.getdbt.com/api/v2/accounts/"
BASE_URL = "https://cloud.getdbt.com/"
SINGLE_TENANT_URL = "https://single.tenant.getdbt.com/"


class TestDbtCloudJobRunStatus:
Expand Down Expand Up @@ -211,7 +211,7 @@ def test_get_account(self, mock_http_run, mock_paginate, conn_id, account_id):
assert hook.method == "GET"

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_called_once_with(endpoint=f"{_account_id}/", data=None)
hook.run.assert_called_once_with(endpoint=f"api/v2/accounts/{_account_id}/", data=None)
hook._paginate.assert_not_called()

@pytest.mark.parametrize(
Expand All @@ -229,7 +229,9 @@ def test_list_projects(self, mock_http_run, mock_paginate, conn_id, account_id):

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_not_called()
hook._paginate.assert_called_once_with(endpoint=f"{_account_id}/projects/", payload=None)
hook._paginate.assert_called_once_with(
endpoint=f"api/v3/accounts/{_account_id}/projects/", payload=None
)

@pytest.mark.parametrize(
argnames="conn_id, account_id",
Expand All @@ -245,7 +247,9 @@ def test_get_project(self, mock_http_run, mock_paginate, conn_id, account_id):
assert hook.method == "GET"

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_called_once_with(endpoint=f"{_account_id}/projects/{PROJECT_ID}/", data=None)
hook.run.assert_called_once_with(
endpoint=f"api/v3/accounts/{_account_id}/projects/{PROJECT_ID}/", data=None
)
hook._paginate.assert_not_called()

@pytest.mark.parametrize(
Expand All @@ -263,7 +267,7 @@ def test_list_jobs(self, mock_http_run, mock_paginate, conn_id, account_id):

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook._paginate.assert_called_once_with(
endpoint=f"{_account_id}/jobs/", payload={"order_by": None, "project_id": None}
endpoint=f"api/v2/accounts/{_account_id}/jobs/", payload={"order_by": None, "project_id": None}
)
hook.run.assert_not_called()

Expand All @@ -282,7 +286,8 @@ def test_list_jobs_with_payload(self, mock_http_run, mock_paginate, conn_id, acc

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook._paginate.assert_called_once_with(
endpoint=f"{_account_id}/jobs/", payload={"order_by": "-id", "project_id": PROJECT_ID}
endpoint=f"api/v2/accounts/{_account_id}/jobs/",
payload={"order_by": "-id", "project_id": PROJECT_ID},
)
hook.run.assert_not_called()

Expand All @@ -300,7 +305,7 @@ def test_get_job(self, mock_http_run, mock_paginate, conn_id, account_id):
assert hook.method == "GET"

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_called_once_with(endpoint=f"{_account_id}/jobs/{JOB_ID}", data=None)
hook.run.assert_called_once_with(endpoint=f"api/v2/accounts/{_account_id}/jobs/{JOB_ID}", data=None)
hook._paginate.assert_not_called()

@pytest.mark.parametrize(
Expand All @@ -319,7 +324,7 @@ def test_trigger_job_run(self, mock_http_run, mock_paginate, conn_id, account_id

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_called_once_with(
endpoint=f"{_account_id}/jobs/{JOB_ID}/run/",
endpoint=f"api/v2/accounts/{_account_id}/jobs/{JOB_ID}/run/",
data=json.dumps({"cause": cause, "steps_override": None, "schema_override": None}),
)
hook._paginate.assert_not_called()
Expand Down Expand Up @@ -348,7 +353,7 @@ def test_trigger_job_run_with_overrides(self, mock_http_run, mock_paginate, conn

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_called_once_with(
endpoint=f"{_account_id}/jobs/{JOB_ID}/run/",
endpoint=f"api/v2/accounts/{_account_id}/jobs/{JOB_ID}/run/",
data=json.dumps(
{"cause": cause, "steps_override": steps_override, "schema_override": schema_override}
),
Expand Down Expand Up @@ -376,7 +381,7 @@ def test_trigger_job_run_with_additional_run_configs(

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_called_once_with(
endpoint=f"{_account_id}/jobs/{JOB_ID}/run/",
endpoint=f"api/v2/accounts/{_account_id}/jobs/{JOB_ID}/run/",
data=json.dumps(
{
"cause": cause,
Expand Down Expand Up @@ -405,7 +410,7 @@ def test_list_job_runs(self, mock_http_run, mock_paginate, conn_id, account_id):
_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_not_called()
hook._paginate.assert_called_once_with(
endpoint=f"{_account_id}/runs/",
endpoint=f"api/v2/accounts/{_account_id}/runs/",
payload={
"include_related": None,
"job_definition_id": None,
Expand All @@ -431,7 +436,7 @@ def test_list_job_runs_with_payload(self, mock_http_run, mock_paginate, conn_id,
_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_not_called()
hook._paginate.assert_called_once_with(
endpoint=f"{_account_id}/runs/",
endpoint=f"api/v2/accounts/{_account_id}/runs/",
payload={
"include_related": ["job"],
"job_definition_id": JOB_ID,
Expand All @@ -452,7 +457,7 @@ def test_get_job_runs(self, mock_http_run, conn_id, account_id):
assert hook.method == "GET"

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_called_once_with(endpoint=f"{_account_id}/runs/", data=None)
hook.run.assert_called_once_with(endpoint=f"api/v2/accounts/{_account_id}/runs/", data=None)

@pytest.mark.parametrize(
argnames="conn_id, account_id",
Expand All @@ -469,7 +474,7 @@ def test_get_job_run(self, mock_http_run, mock_paginate, conn_id, account_id):

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_called_once_with(
endpoint=f"{_account_id}/runs/{RUN_ID}/", data={"include_related": None}
endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/", data={"include_related": None}
)
hook._paginate.assert_not_called()

Expand All @@ -488,7 +493,7 @@ def test_get_job_run_with_payload(self, mock_http_run, mock_paginate, conn_id, a

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_called_once_with(
endpoint=f"{_account_id}/runs/{RUN_ID}/", data={"include_related": ["triggers"]}
endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/", data={"include_related": ["triggers"]}
)
hook._paginate.assert_not_called()

Expand Down Expand Up @@ -543,7 +548,9 @@ def test_cancel_job_run(self, mock_http_run, mock_paginate, conn_id, account_id)
assert hook.method == "POST"

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_called_once_with(endpoint=f"{_account_id}/runs/{RUN_ID}/cancel/", data=None)
hook.run.assert_called_once_with(
endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/cancel/", data=None
)
hook._paginate.assert_not_called()

@pytest.mark.parametrize(
Expand All @@ -561,7 +568,7 @@ def test_list_job_run_artifacts(self, mock_http_run, mock_paginate, conn_id, acc

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_called_once_with(
endpoint=f"{_account_id}/runs/{RUN_ID}/artifacts/", data={"step": None}
endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/artifacts/", data={"step": None}
)
hook._paginate.assert_not_called()

Expand All @@ -579,7 +586,9 @@ def test_list_job_run_artifacts_with_payload(self, mock_http_run, mock_paginate,
assert hook.method == "GET"

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_called_once_with(endpoint=f"{_account_id}/runs/{RUN_ID}/artifacts/", data={"step": 2})
hook.run.assert_called_once_with(
endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/artifacts/", data={"step": 2}
)
hook._paginate.assert_not_called()

@pytest.mark.parametrize(
Expand All @@ -598,7 +607,7 @@ def test_get_job_run_artifact(self, mock_http_run, mock_paginate, conn_id, accou

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_called_once_with(
endpoint=f"{_account_id}/runs/{RUN_ID}/artifacts/{path}", data={"step": None}
endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/artifacts/{path}", data={"step": None}
)
hook._paginate.assert_not_called()

Expand All @@ -618,7 +627,7 @@ def test_get_job_run_artifact_with_payload(self, mock_http_run, mock_paginate, c

_account_id = account_id or DEFAULT_ACCOUNT_ID
hook.run.assert_called_once_with(
endpoint=f"{_account_id}/runs/{RUN_ID}/artifacts/{path}", data={"step": 2}
endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/artifacts/{path}", data={"step": 2}
)
hook._paginate.assert_not_called()

Expand Down