Skip to content

Commit

Permalink
Dynamic setting up of artifact versions for Datafusion pipelines (#34068
Browse files Browse the repository at this point in the history
)
  • Loading branch information
moiseenkov authored Sep 4, 2023
1 parent ba59f34 commit c88e746
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 294 deletions.
21 changes: 19 additions & 2 deletions airflow/providers/google/cloud/hooks/datafusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def _base_url(instance_url: str, namespace: str) -> str:
return os.path.join(instance_url, "v3", "namespaces", quote(namespace), "apps")

def _cdap_request(
self, url: str, method: str, body: list | dict | None = None
self, url: str, method: str, body: list | dict | None = None, params: dict | None = None
) -> google.auth.transport.Response:
headers: dict[str, str] = {"Content-Type": "application/json"}
request = google.auth.transport.requests.Request()
Expand All @@ -163,7 +163,7 @@ def _cdap_request(

payload = json.dumps(body) if body else None

response = request(method=method, url=url, headers=headers, body=payload)
response = request(method=method, url=url, headers=headers, body=payload, params=params)
return response

@staticmethod
Expand Down Expand Up @@ -282,6 +282,23 @@ def get_instance(self, instance_name: str, location: str, project_id: str) -> di
)
return instance

def get_instance_artifacts(
    self, instance_url: str, namespace: str = "default", scope: str = "SYSTEM"
) -> Any:
    """
    Fetch the artifacts listed under a namespace of a Data Fusion instance.

    :param instance_url: Base URL of the instance API; the artifacts path is appended to it.
    :param namespace: Namespace whose artifacts are requested. It is URL-quoted before use.
    :param scope: Value sent as the ``scope`` query parameter of the request.
    :return: The JSON-decoded body of the response.
    """
    path_parts = ("v3", "namespaces", quote(namespace), "artifacts")
    artifacts_url = os.path.join(instance_url, *path_parts)
    query = {"scope": scope}

    response = self._cdap_request(url=artifacts_url, method="GET", params=query)

    # Delegate status/body validation to the shared helper before decoding the payload.
    failure_message = f"Retrieving an instance artifacts failed with code {response.status}"
    self._check_response_status_and_data(response, failure_message)

    return json.loads(response.data)

@GoogleBaseHook.fallback_to_default_project_id
def patch_instance(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ It is not possible to use both asynchronous and deferrable parameters at the sam
Please, check the example of using deferrable mode:
:class:`~airflow.providers.google.cloud.operators.datafusion.CloudDataFusionStartPipelineOperator`.

.. exampleinclude:: /../../tests/system/providers/google/cloud/datafusion/example_datafusion_async.py
.. exampleinclude:: /../../tests/system/providers/google/cloud/datafusion/example_datafusion.py
:language: python
:dedent: 4
:start-after: [START howto_cloud_data_fusion_start_pipeline_def]
Expand Down
25 changes: 23 additions & 2 deletions tests/providers/google/cloud/hooks/test_datafusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,24 @@ def test_get_instance(self, get_conn_mock, hook):
assert result == "value"
method_mock.assert_called_once_with(name=hook._name(PROJECT_ID, LOCATION, INSTANCE_NAME))

@mock.patch(HOOK_STR.format("DataFusionHook._cdap_request"))
def test_get_instance_artifacts(self, mock_request, hook):
    """Check that the hook requests the artifacts endpoint and returns the decoded payload."""
    scope = "SYSTEM"
    artifact = {
        "name": "test-artifact",
        "version": "1.2.3",
        "scope": scope,
    }
    mock_request.return_value = mock.MagicMock(status=200, data=json.dumps([artifact]))

    result = hook.get_instance_artifacts(instance_url=INSTANCE_URL, scope=scope)

    # Verify the JSON body is decoded and returned, not only that a request was issued.
    assert result == [artifact]
    # Exactly one request is expected; assert_called_with would not catch duplicates.
    mock_request.assert_called_once_with(
        url=f"{INSTANCE_URL}/v3/namespaces/default/artifacts",
        method="GET",
        params={"scope": scope},
    )

@mock.patch("google.auth.transport.requests.Request")
@mock.patch(HOOK_STR.format("DataFusionHook.get_credentials"))
def test_cdap_request(self, get_credentials_mock, mock_request, hook):
Expand All @@ -177,14 +195,17 @@ def test_cdap_request(self, get_credentials_mock, mock_request, hook):
request = mock_request.return_value
request.return_value = mock.MagicMock()
body = {"data": "value"}
params = {"param_key": "param_value"}

result = hook._cdap_request(url=url, method=method, body=body)
result = hook._cdap_request(url=url, method=method, body=body, params=params)
mock_request.assert_called_once_with()
get_credentials_mock.assert_called_once_with()
get_credentials_mock.return_value.before_request.assert_called_once_with(
request=request, method=method, url=url, headers=headers
)
request.assert_called_once_with(method=method, url=url, headers=headers, body=json.dumps(body))
request.assert_called_once_with(
method=method, url=url, headers=headers, body=json.dumps(body), params=params
)
assert result == request.return_value

@mock.patch(HOOK_STR.format("DataFusionHook._cdap_request"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from datetime import datetime

from airflow import models
from airflow.decorators import task
from airflow.providers.google.cloud.hooks.datafusion import DataFusionHook
from airflow.providers.google.cloud.operators.datafusion import (
CloudDataFusionCreateInstanceOperator,
CloudDataFusionCreatePipelineOperator,
Expand Down Expand Up @@ -61,7 +63,7 @@
PIPELINE = {
"artifact": {
"name": "cdap-data-pipeline",
"version": "6.8.3",
"version": "{{ task_instance.xcom_pull(task_ids='get_artifacts_versions')['cdap-data-pipeline'] }}",
"scope": "SYSTEM",
},
"description": "Data Pipeline Application",
Expand All @@ -82,7 +84,12 @@
"name": "GCSFile",
"type": "batchsource",
"label": "GCS",
"artifact": {"name": "google-cloud", "version": "0.21.2", "scope": "SYSTEM"},
"artifact": {
"name": "google-cloud",
"version": "{{ task_instance.xcom_pull(task_ids='get_artifacts_versions')\
['google-cloud'] }}",
"scope": "SYSTEM",
},
"properties": {
"project": "auto-detect",
"format": "text",
Expand Down Expand Up @@ -111,7 +118,12 @@
"name": "GCS",
"type": "batchsink",
"label": "GCS2",
"artifact": {"name": "google-cloud", "version": "0.21.2", "scope": "SYSTEM"},
"artifact": {
"name": "google-cloud",
"version": "{{ task_instance.xcom_pull(task_ids='get_artifacts_versions')\
['google-cloud'] }}",
"scope": "SYSTEM",
},
"properties": {
"project": "auto-detect",
"suffix": "yyyy-MM-dd-HH-mm",
Expand Down Expand Up @@ -147,6 +159,9 @@
}
# [END howto_data_fusion_env_variables]

# Add "pipeline" to the operator's templated fields. PIPELINE embeds Jinja expressions
# (artifact versions pulled from XCom), so presumably this is what lets them be rendered
# when the create-pipeline task runs — confirm against the operator's `pipeline` argument.
CloudDataFusionCreatePipelineOperator.template_fields += ("pipeline",)


with models.DAG(
DAG_ID,
start_date=datetime(2021, 1, 1),
Expand Down Expand Up @@ -196,6 +211,13 @@
)
# [END howto_cloud_data_fusion_update_instance_operator]

@task(task_id="get_artifacts_versions")
def get_artifacts_versions(ti) -> dict:
    """
    Build a mapping of artifact name to version for the Data Fusion instance.

    The instance API endpoint is read from the XCom return value of the ``get_instance``
    task, and the artifacts are listed for the ``default`` namespace.
    """
    datafusion_hook = DataFusionHook()
    instance_details = ti.xcom_pull(task_ids="get_instance", key="return_value")
    api_endpoint = instance_details["apiEndpoint"]
    available = datafusion_hook.get_instance_artifacts(instance_url=api_endpoint, namespace="default")

    versions = {}
    for artifact in available:
        versions[artifact["name"]] = artifact["version"]
    return versions

# [START howto_cloud_data_fusion_create_pipeline]
create_pipeline = CloudDataFusionCreatePipelineOperator(
location=LOCATION,
Expand All @@ -221,6 +243,16 @@
)
# [END howto_cloud_data_fusion_start_pipeline]

# Start the pipeline with the operator in deferrable mode (deferrable=True).
# The START/END markers below delimit the snippet pulled into the docs via exampleinclude.
# [START howto_cloud_data_fusion_start_pipeline_def]
start_pipeline_def = CloudDataFusionStartPipelineOperator(
location=LOCATION,
pipeline_name=PIPELINE_NAME,
instance_name=INSTANCE_NAME,
task_id="start_pipeline_def",
deferrable=True,
)
# [END howto_cloud_data_fusion_start_pipeline_def]

# [START howto_cloud_data_fusion_start_pipeline_async]
start_pipeline_async = CloudDataFusionStartPipelineOperator(
location=LOCATION,
Expand Down Expand Up @@ -284,10 +316,12 @@
# TEST BODY
>> create_instance
>> get_instance
>> get_artifacts_versions()
>> restart_instance
>> update_instance
>> create_pipeline
>> list_pipelines
>> start_pipeline_def
>> start_pipeline_async
>> start_pipeline_sensor
>> start_pipeline
Expand Down
Loading

0 comments on commit c88e746

Please sign in to comment.