From 07c3c8213b65195dae531e656db388b80bc4b240 Mon Sep 17 00:00:00 2001 From: Maksim Moiseenkov Date: Fri, 6 Oct 2023 12:21:11 +0000 Subject: [PATCH] Make Dataprep system test self-sufficient --- .../providers/google/cloud/hooks/dataprep.py | 102 ++++- .../google/cloud/operators/dataprep.py | 4 +- .../google/cloud/hooks/test_dataprep.py | 383 +++++++++++++++++- .../google/cloud/operators/test_dataprep.py | 2 +- .../google/cloud/dataprep/example_dataprep.py | 218 ++++++++-- 5 files changed, 652 insertions(+), 57 deletions(-) diff --git a/airflow/providers/google/cloud/hooks/dataprep.py b/airflow/providers/google/cloud/hooks/dataprep.py index c01a48d5ae421..9e006fa99ffe4 100644 --- a/airflow/providers/google/cloud/hooks/dataprep.py +++ b/airflow/providers/google/cloud/hooks/dataprep.py @@ -72,9 +72,10 @@ class GoogleDataprepHook(BaseHook): conn_type = "dataprep" hook_name = "Google Dataprep" - def __init__(self, dataprep_conn_id: str = default_conn_name) -> None: + def __init__(self, dataprep_conn_id: str = default_conn_name, api_version: str = "v4") -> None: super().__init__() self.dataprep_conn_id = dataprep_conn_id + self.api_version = api_version conn = self.get_connection(self.dataprep_conn_id) extras = conn.extra_dejson self._token = _get_field(extras, "token") @@ -95,7 +96,7 @@ def get_jobs_for_job_group(self, job_id: int) -> dict[str, Any]: :param job_id: The ID of the job that will be fetched """ - endpoint_path = f"v4/jobGroups/{job_id}/jobs" + endpoint_path = f"{self.api_version}/jobGroups/{job_id}/jobs" url: str = urljoin(self._base_url, endpoint_path) response = requests.get(url, headers=self._headers) self._raise_for_status(response) @@ -113,7 +114,7 @@ def get_job_group(self, job_group_id: int, embed: str, include_deleted: bool) -> :param include_deleted: if set to "true", will include deleted objects """ params: dict[str, Any] = {"embed": embed, "includeDeleted": include_deleted} - endpoint_path = f"v4/jobGroups/{job_group_id}" + endpoint_path = f"{self.api_version}/jobGroups/{job_group_id}" url: str = urljoin(self._base_url, endpoint_path) response = requests.get(url, headers=self._headers, params=params) self._raise_for_status(response) @@ -131,12 +132,26 @@ def run_job_group(self, body_request: dict) -> dict[str, Any]: :param body_request: The identifier for the recipe you would like to run. """ - endpoint_path = "v4/jobGroups" + endpoint_path = f"{self.api_version}/jobGroups" url: str = urljoin(self._base_url, endpoint_path) response = requests.post(url, headers=self._headers, data=json.dumps(body_request)) self._raise_for_status(response) return response.json() + @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10)) + def create_flow(self, *, body_request: dict) -> dict: + """ + Creates flow. + + :param body_request: Body of the POST request to be sent. + For more details check https://clouddataprep.com/documentation/api#operation/createFlow + """ + endpoint = f"/{self.api_version}/flows" + url: str = urljoin(self._base_url, endpoint) + response = requests.post(url, headers=self._headers, data=json.dumps(body_request)) + self._raise_for_status(response) + return response.json() + @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10)) def copy_flow( self, *, flow_id: int, name: str = "", description: str = "", copy_datasources: bool = False @@ -149,7 +164,7 @@ def copy_flow( :param description: Description of the copy of the flow :param copy_datasources: Bool value to define should copies of data inputs be made or not. """ - endpoint_path = f"v4/flows/{flow_id}/copy" + endpoint_path = f"{self.api_version}/flows/{flow_id}/copy" url: str = urljoin(self._base_url, endpoint_path) body_request = { "name": name, @@ -167,7 +182,7 @@ def delete_flow(self, *, flow_id: int) -> None: :param flow_id: ID of the flow to be copied """ - endpoint_path = f"v4/flows/{flow_id}" + endpoint_path = f"{self.api_version}/flows/{flow_id}" url: str = urljoin(self._base_url, endpoint_path) response = requests.delete(url, headers=self._headers) self._raise_for_status(response) @@ -180,7 +195,7 @@ def run_flow(self, *, flow_id: int, body_request: dict) -> dict: :param flow_id: ID of the flow to be copied :param body_request: Body of the POST request to be sent. """ - endpoint = f"v4/flows/{flow_id}/run" + endpoint = f"{self.api_version}/flows/{flow_id}/run" url: str = urljoin(self._base_url, endpoint) response = requests.post(url, headers=self._headers, data=json.dumps(body_request)) self._raise_for_status(response) @@ -193,7 +208,7 @@ def get_job_group_status(self, *, job_group_id: int) -> JobGroupStatuses: :param job_group_id: ID of the job group to check """ - endpoint = f"/v4/jobGroups/{job_group_id}/status" + endpoint = f"/{self.api_version}/jobGroups/{job_group_id}/status" url: str = urljoin(self._base_url, endpoint) response = requests.get(url, headers=self._headers) self._raise_for_status(response) @@ -205,3 +220,74 @@ def _raise_for_status(self, response: requests.models.Response) -> None: except HTTPError: self.log.error(response.json().get("exception")) raise + + @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10)) + def create_imported_dataset(self, *, body_request: dict) -> dict: + """ + Creates imported dataset. + + :param body_request: Body of the POST request to be sent. + For more details check https://clouddataprep.com/documentation/api#operation/createImportedDataset + """ + endpoint = f"/{self.api_version}/importedDatasets" + url: str = urljoin(self._base_url, endpoint) + response = requests.post(url, headers=self._headers, data=json.dumps(body_request)) + self._raise_for_status(response) + return response.json() + + @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10)) + def create_wrangled_dataset(self, *, body_request: dict) -> dict: + """ + Creates wrangled dataset. + + :param body_request: Body of the POST request to be sent. + For more details check + https://clouddataprep.com/documentation/api#operation/createWrangledDataset + """ + endpoint = f"/{self.api_version}/wrangledDatasets" + url: str = urljoin(self._base_url, endpoint) + response = requests.post(url, headers=self._headers, data=json.dumps(body_request)) + self._raise_for_status(response) + return response.json() + + @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10)) + def create_output_object(self, *, body_request: dict) -> dict: + """ + Creates output. + + :param body_request: Body of the POST request to be sent. + For more details check + https://clouddataprep.com/documentation/api#operation/createOutputObject + """ + endpoint = f"/{self.api_version}/outputObjects" + url: str = urljoin(self._base_url, endpoint) + response = requests.post(url, headers=self._headers, data=json.dumps(body_request)) + self._raise_for_status(response) + return response.json() + + @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10)) + def create_write_settings(self, *, body_request: dict) -> dict: + """ + Creates write settings. + + :param body_request: Body of the POST request to be sent. + For more details check + https://clouddataprep.com/documentation/api#tag/createWriteSetting + """ + endpoint = f"/{self.api_version}/writeSettings" + url: str = urljoin(self._base_url, endpoint) + response = requests.post(url, headers=self._headers, data=json.dumps(body_request)) + self._raise_for_status(response) + return response.json() + + @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, max=10)) + def delete_imported_dataset(self, *, dataset_id: int) -> None: + """ + Deletes imported dataset. + + :param dataset_id: ID of the imported dataset for removal. + """ + endpoint = f"/{self.api_version}/importedDatasets/{dataset_id}" + url: str = urljoin(self._base_url, endpoint) + response = requests.delete(url, headers=self._headers) + self._raise_for_status(response) diff --git a/airflow/providers/google/cloud/operators/dataprep.py b/airflow/providers/google/cloud/operators/dataprep.py index 7f19f6993b7ce..59710293e6b16 100644 --- a/airflow/providers/google/cloud/operators/dataprep.py +++ b/airflow/providers/google/cloud/operators/dataprep.py @@ -51,13 +51,13 @@ def __init__( **kwargs, ) -> None: super().__init__(**kwargs) - self.dataprep_conn_id = (dataprep_conn_id,) + self.dataprep_conn_id = dataprep_conn_id self.job_group_id = job_group_id def execute(self, context: Context) -> dict: self.log.info("Fetching data for job with id: %d ...", self.job_group_id) hook = GoogleDataprepHook( - dataprep_conn_id="dataprep_default", + dataprep_conn_id=self.dataprep_conn_id, ) response = hook.get_jobs_for_job_group(job_id=int(self.job_group_id)) return response diff --git a/tests/providers/google/cloud/hooks/test_dataprep.py b/tests/providers/google/cloud/hooks/test_dataprep.py index e29d2be3dca9e..a0cef77ae9ebc 100644 --- a/tests/providers/google/cloud/hooks/test_dataprep.py +++ b/tests/providers/google/cloud/hooks/test_dataprep.py @@ -35,7 +35,12 @@ EMBED = "" INCLUDE_DELETED = False DATA = {"wrangledDataset": {"id": RECIPE_ID}} -URL = "https://api.clouddataprep.com/v4/jobGroups" +URL_BASE = "https://api.clouddataprep.com" +URL_JOB_GROUPS = URL_BASE + "/v4/jobGroups" +URL_IMPORTED_DATASETS = URL_BASE + "/v4/importedDatasets" +URL_WRANGLED_DATASETS = URL_BASE + "/v4/wrangledDatasets" +URL_OUTPUT_OBJECTS = URL_BASE + "/v4/outputObjects" +URL_WRITE_SETTINGS = URL_BASE + "/v4/writeSettings" class TestGoogleDataprepHook: @@ -43,12 +48,41 @@ def setup_method(self): with mock.patch("airflow.hooks.base.BaseHook.get_connection") as conn: conn.return_value.extra_dejson = EXTRA self.hook = GoogleDataprepHook(dataprep_conn_id="dataprep_default") + self._imported_dataset_id = 12345 + self._create_imported_dataset_body_request = { + "uri": "gs://test/uri", + "name": "test_name", + } + self._create_wrangled_dataset_body_request = { + "importedDataset": {"id": "test_dataset_id"}, + "flow": {"id": "test_flow_id"}, + "name": "test_dataset_name", + } + self._create_output_object_body_request = { + "execution": "dataflow", + "profiler": False, + "flowNodeId": "test_flow_node_id", + } + self._create_write_settings_body_request = { + "path": "gs://test/path", + "action": "create", + "format": "csv", + "outputObjectId": "test_output_object_id", + } + self._expected_create_imported_dataset_hook_data = json.dumps( + self._create_imported_dataset_body_request + ) + self._expected_create_wrangled_dataset_hook_data = json.dumps( + self._create_wrangled_dataset_body_request + ) + self._expected_create_output_object_hook_data = json.dumps(self._create_output_object_body_request) + self._expected_create_write_settings_hook_data = json.dumps(self._create_write_settings_body_request) @patch("airflow.providers.google.cloud.hooks.dataprep.requests.get") def test_get_jobs_for_job_group_should_be_called_once_with_params(self, mock_get_request): self.hook.get_jobs_for_job_group(JOB_ID) mock_get_request.assert_called_once_with( - f"{URL}/{JOB_ID}/jobs", + f"{URL_JOB_GROUPS}/{JOB_ID}/jobs", headers={"Content-Type": "application/json", "Authorization": f"Bearer {TOKEN}"}, ) @@ -93,7 +127,7 @@ def test_get_jobs_for_job_group_raise_error_after_five_calls(self, mock_get_requ def test_get_job_group_should_be_called_once_with_params(self, mock_get_request): self.hook.get_job_group(JOB_ID, EMBED, INCLUDE_DELETED) mock_get_request.assert_called_once_with( - f"{URL}/{JOB_ID}", + f"{URL_JOB_GROUPS}/{JOB_ID}", headers={ "Content-Type": "application/json", "Authorization": f"Bearer {TOKEN}", @@ -148,7 +182,7 @@ def test_get_job_group_raise_error_after_five_calls(self, mock_get_request): def test_run_job_group_should_be_called_once_with_params(self, mock_get_request): self.hook.run_job_group(body_request=DATA) mock_get_request.assert_called_once_with( - f"{URL}", + f"{URL_JOB_GROUPS}", headers={ "Content-Type": "application/json", "Authorization": f"Bearer {TOKEN}", @@ -203,7 +237,7 @@ def test_run_job_group_raise_error_after_five_calls(self, mock_get_request): def test_get_job_group_status_should_be_called_once_with_params(self, mock_get_request): self.hook.get_job_group_status(job_group_id=JOB_ID) mock_get_request.assert_called_once_with( - f"{URL}/{JOB_ID}/status", + f"{URL_JOB_GROUPS}/{JOB_ID}/status", headers={ "Content-Type": "application/json", "Authorization": f"Bearer {TOKEN}", @@ -266,12 +300,290 @@ def test_conn_extra_backcompat_prefix(self, uri): assert hook._token == "abc" assert hook._base_url == "abc" + @patch("airflow.providers.google.cloud.hooks.dataprep.requests.post") + def test_create_imported_dataset_should_be_called_once_with_params(self, mock_post_request): + self.hook.create_imported_dataset(body_request=self._create_imported_dataset_body_request) + mock_post_request.assert_called_once_with( + URL_IMPORTED_DATASETS, + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {TOKEN}", + }, + data=self._expected_create_imported_dataset_hook_data, + ) + + @patch( + "airflow.providers.google.cloud.hooks.dataprep.requests.post", + side_effect=[HTTPError(), mock.MagicMock()], + ) + def test_create_imported_dataset_should_pass_after_retry(self, mock_post_request): + self.hook.create_imported_dataset(body_request=self._create_imported_dataset_body_request) + assert mock_post_request.call_count == 2 + + @patch( + "airflow.providers.google.cloud.hooks.dataprep.requests.post", + side_effect=[mock.MagicMock(), HTTPError()], + ) + def test_create_imported_dataset_retry_after_success(self, mock_post_request): + self.hook.create_imported_dataset.retry.sleep = mock.Mock() + self.hook.create_imported_dataset(body_request=self._create_imported_dataset_body_request) + assert mock_post_request.call_count == 1 + + @patch( + "airflow.providers.google.cloud.hooks.dataprep.requests.post", + side_effect=[ + HTTPError(), + HTTPError(), + HTTPError(), + HTTPError(), + mock.MagicMock(), + ], + ) + def test_create_imported_dataset_four_errors(self, mock_post_request): + self.hook.create_imported_dataset.retry.sleep = mock.Mock() + self.hook.create_imported_dataset(body_request=self._create_imported_dataset_body_request) + assert mock_post_request.call_count == 5 + + @patch( + "airflow.providers.google.cloud.hooks.dataprep.requests.post", + side_effect=[HTTPError(), HTTPError(), HTTPError(), HTTPError(), HTTPError()], + ) + def test_create_imported_dataset_five_calls(self, mock_post_request): + with pytest.raises(RetryError) as ctx: + self.hook.create_imported_dataset.retry.sleep = mock.Mock() + self.hook.create_imported_dataset(body_request=self._create_imported_dataset_body_request) + assert "HTTPError" in str(ctx.value) + assert mock_post_request.call_count == 5 + + @patch("airflow.providers.google.cloud.hooks.dataprep.requests.post") + def test_create_wrangled_dataset_should_be_called_once_with_params(self, mock_post_request): + self.hook.create_wrangled_dataset(body_request=self._create_wrangled_dataset_body_request) + mock_post_request.assert_called_once_with( + URL_WRANGLED_DATASETS, + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {TOKEN}", + }, + data=self._expected_create_wrangled_dataset_hook_data, + ) + + @patch( + "airflow.providers.google.cloud.hooks.dataprep.requests.post", + side_effect=[HTTPError(), mock.MagicMock()], + ) + def test_create_wrangled_dataset_should_pass_after_retry(self, mock_post_request): + self.hook.create_wrangled_dataset(body_request=self._create_wrangled_dataset_body_request) + assert mock_post_request.call_count == 2 + + @patch( + "airflow.providers.google.cloud.hooks.dataprep.requests.post", + side_effect=[mock.MagicMock(), HTTPError()], + ) + def test_create_wrangled_dataset_retry_after_success(self, mock_post_request): + self.hook.create_wrangled_dataset.retry.sleep = mock.Mock() + self.hook.create_wrangled_dataset(body_request=self._create_wrangled_dataset_body_request) + assert mock_post_request.call_count == 1 + + @patch( + "airflow.providers.google.cloud.hooks.dataprep.requests.post", + side_effect=[ + HTTPError(), + HTTPError(), + HTTPError(), + HTTPError(), + mock.MagicMock(), + ], + ) + def test_create_wrangled_dataset_four_errors(self, mock_post_request): + self.hook.create_wrangled_dataset.retry.sleep = mock.Mock() + self.hook.create_wrangled_dataset(body_request=self._create_wrangled_dataset_body_request) + assert mock_post_request.call_count == 5 + + @patch( + "airflow.providers.google.cloud.hooks.dataprep.requests.post", + side_effect=[HTTPError(), HTTPError(), HTTPError(), HTTPError(), HTTPError()], + ) + def test_create_wrangled_dataset_five_calls(self, mock_post_request): + with pytest.raises(RetryError) as ctx: + self.hook.create_wrangled_dataset.retry.sleep = mock.Mock() + self.hook.create_wrangled_dataset(body_request=self._create_wrangled_dataset_body_request) + assert "HTTPError" in str(ctx.value) + assert mock_post_request.call_count == 5 + + @patch("airflow.providers.google.cloud.hooks.dataprep.requests.post") + def test_create_output_object_should_be_called_once_with_params(self, mock_post_request): + self.hook.create_output_object(body_request=self._create_output_object_body_request) + mock_post_request.assert_called_once_with( + URL_OUTPUT_OBJECTS, + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {TOKEN}", + }, + data=self._expected_create_output_object_hook_data, + ) + + @patch( + "airflow.providers.google.cloud.hooks.dataprep.requests.post", + side_effect=[HTTPError(), mock.MagicMock()], + ) + def test_create_output_objects_should_pass_after_retry(self, mock_post_request): + self.hook.create_output_object(body_request=self._create_output_object_body_request) + assert mock_post_request.call_count == 2 + + @patch( + "airflow.providers.google.cloud.hooks.dataprep.requests.post", + side_effect=[mock.MagicMock(), HTTPError()], + ) + def test_create_output_objects_retry_after_success(self, mock_post_request): + self.hook.create_output_object.retry.sleep = mock.Mock() + self.hook.create_output_object(body_request=self._create_output_object_body_request) + assert mock_post_request.call_count == 1 + + @patch( + "airflow.providers.google.cloud.hooks.dataprep.requests.post", + side_effect=[ + HTTPError(), + HTTPError(), + HTTPError(), + HTTPError(), + mock.MagicMock(), + ], + ) + def test_create_output_objects_four_errors(self, mock_post_request): + self.hook.create_output_object.retry.sleep = mock.Mock() + self.hook.create_output_object(body_request=self._create_output_object_body_request) + assert mock_post_request.call_count == 5 + + @patch( + "airflow.providers.google.cloud.hooks.dataprep.requests.post", + side_effect=[HTTPError(), HTTPError(), HTTPError(), HTTPError(), HTTPError()], + ) + def test_create_output_objects_five_calls(self, mock_post_request): + with pytest.raises(RetryError) as ctx: + self.hook.create_output_object.retry.sleep = mock.Mock() + self.hook.create_output_object(body_request=self._create_output_object_body_request) + assert "HTTPError" in str(ctx.value) + assert mock_post_request.call_count == 5 + + @patch("airflow.providers.google.cloud.hooks.dataprep.requests.post") + def test_create_write_settings_should_be_called_once_with_params(self, mock_post_request): + self.hook.create_write_settings(body_request=self._create_write_settings_body_request) + mock_post_request.assert_called_once_with( + URL_WRITE_SETTINGS, + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {TOKEN}", + }, + data=self._expected_create_write_settings_hook_data, + ) + + @patch( + "airflow.providers.google.cloud.hooks.dataprep.requests.post", + side_effect=[HTTPError(), mock.MagicMock()], + ) + def test_create_write_settings_should_pass_after_retry(self, mock_post_request): + self.hook.create_write_settings(body_request=self._create_write_settings_body_request) + assert mock_post_request.call_count == 2 + + @patch( + "airflow.providers.google.cloud.hooks.dataprep.requests.post", + side_effect=[mock.MagicMock(), HTTPError()], + ) + def test_create_write_settings_retry_after_success(self, mock_post_request): + self.hook.create_write_settings.retry.sleep = mock.Mock() + self.hook.create_write_settings(body_request=self._create_write_settings_body_request) + assert mock_post_request.call_count == 1 + + @patch( + "airflow.providers.google.cloud.hooks.dataprep.requests.post", + side_effect=[ + HTTPError(), + HTTPError(), + HTTPError(), + HTTPError(), + mock.MagicMock(), + ], + ) + def test_create_write_settings_four_errors(self, mock_post_request): + self.hook.create_write_settings.retry.sleep = mock.Mock() + self.hook.create_write_settings(body_request=self._create_write_settings_body_request) + assert mock_post_request.call_count == 5 + + @patch( + "airflow.providers.google.cloud.hooks.dataprep.requests.post", + side_effect=[HTTPError(), HTTPError(), HTTPError(), HTTPError(), HTTPError()], + ) + def test_create_write_settings_five_calls(self, mock_post_request): + with pytest.raises(RetryError) as ctx: + self.hook.create_write_settings.retry.sleep = mock.Mock() + self.hook.create_write_settings(body_request=self._create_write_settings_body_request) + assert "HTTPError" in str(ctx.value) + assert mock_post_request.call_count == 5 + + @patch("airflow.providers.google.cloud.hooks.dataprep.requests.delete") + def test_delete_imported_dataset_should_be_called_once_with_params(self, mock_delete_request): + self.hook.delete_imported_dataset(dataset_id=self._imported_dataset_id) + mock_delete_request.assert_called_once_with( + f"{URL_IMPORTED_DATASETS}/{self._imported_dataset_id}", + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {TOKEN}", + }, + ) + + @patch( + "airflow.providers.google.cloud.hooks.dataprep.requests.delete", + side_effect=[HTTPError(), mock.MagicMock()], + ) + def test_delete_imported_dataset_should_pass_after_retry(self, mock_delete_request): + self.hook.delete_imported_dataset(dataset_id=self._imported_dataset_id) + assert mock_delete_request.call_count == 2 + + @patch( + "airflow.providers.google.cloud.hooks.dataprep.requests.delete", + side_effect=[mock.MagicMock(), HTTPError()], + ) + def test_delete_imported_dataset_retry_after_success(self, mock_delete_request): + self.hook.delete_imported_dataset.retry.sleep = mock.Mock() + self.hook.delete_imported_dataset(dataset_id=self._imported_dataset_id) + assert mock_delete_request.call_count == 1 + + @patch( + "airflow.providers.google.cloud.hooks.dataprep.requests.delete", + side_effect=[ + HTTPError(), + HTTPError(), + HTTPError(), + HTTPError(), + mock.MagicMock(), + ], + ) + def test_delete_imported_dataset_four_errors(self, mock_delete_request): + self.hook.delete_imported_dataset.retry.sleep = mock.Mock() + self.hook.delete_imported_dataset(dataset_id=self._imported_dataset_id) + assert mock_delete_request.call_count == 5 + + @patch( + "airflow.providers.google.cloud.hooks.dataprep.requests.delete", + side_effect=[HTTPError(), HTTPError(), HTTPError(), HTTPError(), HTTPError()], + ) + def test_delete_imported_dataset_five_calls(self, mock_delete_request): + with pytest.raises(RetryError) as ctx: + self.hook.delete_imported_dataset.retry.sleep = mock.Mock() + self.hook.delete_imported_dataset(dataset_id=self._imported_dataset_id) + assert "HTTPError" in str(ctx.value) + assert mock_delete_request.call_count == 5 + class TestGoogleDataprepFlowPathHooks: _url = "https://api.clouddataprep.com/v4/flows" def setup_method(self): self._flow_id = 1234567 + self._create_flow_body_request = { + "name": "test_name", + "description": "Test description", + } self._expected_copy_flow_hook_data = json.dumps( { "name": "", @@ -280,10 +592,71 @@ def setup_method(self): } ) self._expected_run_flow_hook_data = json.dumps({}) + self._expected_create_flow_hook_data = json.dumps( + { + "name": "test_name", + "description": "Test description", + } + ) with mock.patch("airflow.hooks.base.BaseHook.get_connection") as conn: conn.return_value.extra_dejson = EXTRA self.hook = GoogleDataprepHook(dataprep_conn_id="dataprep_default") + @patch("airflow.providers.google.cloud.hooks.dataprep.requests.post") + def test_create_flow_should_be_called_once_with_params(self, mock_post_request): + self.hook.create_flow(body_request=self._create_flow_body_request) + mock_post_request.assert_called_once_with( + self._url, + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {TOKEN}", + }, + data=self._expected_create_flow_hook_data, + ) + + @patch( + "airflow.providers.google.cloud.hooks.dataprep.requests.post", + side_effect=[HTTPError(), mock.MagicMock()], + ) + def test_create_flow_should_pass_after_retry(self, mock_post_request): + self.hook.create_flow(body_request=self._create_flow_body_request) + assert mock_post_request.call_count == 2 + + @patch( + "airflow.providers.google.cloud.hooks.dataprep.requests.post", + side_effect=[mock.MagicMock(), HTTPError()], + ) + def test_create_flow_should_not_retry_after_success(self, mock_post_request): + self.hook.create_flow.retry.sleep = mock.Mock() + self.hook.create_flow(body_request=self._create_flow_body_request) + assert mock_post_request.call_count == 1 + + @patch( + "airflow.providers.google.cloud.hooks.dataprep.requests.post", + side_effect=[ + HTTPError(), + HTTPError(), + HTTPError(), + HTTPError(), + mock.MagicMock(), + ], + ) + def test_create_flow_should_retry_after_four_errors(self, mock_post_request): + self.hook.create_flow.retry.sleep = mock.Mock() + self.hook.create_flow(body_request=self._create_flow_body_request) + assert mock_post_request.call_count == 5 + + @patch( + "airflow.providers.google.cloud.hooks.dataprep.requests.post", + side_effect=[HTTPError(), HTTPError(), HTTPError(), HTTPError(), HTTPError()], + ) + def test_create_flow_raise_error_after_five_calls(self, mock_post_request): + with pytest.raises(RetryError) as ctx: + self.hook.create_flow.retry.sleep = mock.Mock() + self.hook.create_flow(body_request=self._create_flow_body_request) + assert "HTTPError" in str(ctx.value) + assert mock_post_request.call_count == 5 + @patch("airflow.providers.google.cloud.hooks.dataprep.requests.post") def test_copy_flow_should_be_called_once_with_params(self, mock_get_request): self.hook.copy_flow( diff --git a/tests/providers/google/cloud/operators/test_dataprep.py b/tests/providers/google/cloud/operators/test_dataprep.py index 08237d0d5193a..d5800716a8716 100644 --- a/tests/providers/google/cloud/operators/test_dataprep.py +++ b/tests/providers/google/cloud/operators/test_dataprep.py @@ -66,7 +66,7 @@ def test_execute(self, hook_mock): dataprep_conn_id=DATAPREP_CONN_ID, job_group_id=JOB_ID, task_id=TASK_ID ) op.execute(context={}) - hook_mock.assert_called_once_with(dataprep_conn_id="dataprep_default") + hook_mock.assert_called_once_with(dataprep_conn_id=DATAPREP_CONN_ID) hook_mock.return_value.get_jobs_for_job_group.assert_called_once_with(job_id=JOB_ID) diff --git a/tests/system/providers/google/cloud/dataprep/example_dataprep.py b/tests/system/providers/google/cloud/dataprep/example_dataprep.py index c07cd5a4562df..126e3f4c9b6db 100644 --- a/tests/system/providers/google/cloud/dataprep/example_dataprep.py +++ b/tests/system/providers/google/cloud/dataprep/example_dataprep.py @@ -16,13 +16,25 @@ # under the License. """ Example Airflow DAG that shows how to use Google Dataprep. + +This DAG relies on the following OS environment variables + +* SYSTEM_TESTS_DATAPREP_TOKEN - Dataprep API access token. + For generating it please use instruction + https://docs.trifacta.com/display/DP/Manage+API+Access+Tokens#:~:text=Enable%20individual%20access-,Generate%20New%20Token,-Via%20UI. """ from __future__ import annotations +import logging import os from datetime import datetime -from airflow.models.dag import DAG +from airflow import models +from airflow.decorators import task +from airflow.models import Connection +from airflow.models.baseoperator import chain +from airflow.operators.bash import BashOperator +from airflow.providers.google.cloud.hooks.dataprep import GoogleDataprepHook from airflow.providers.google.cloud.operators.dataprep import ( DataprepCopyFlowOperator, DataprepDeleteFlowOperator, @@ -33,31 +45,40 @@ ) from airflow.providers.google.cloud.operators.gcs import GCSCreateBucketOperator, GCSDeleteBucketOperator from airflow.providers.google.cloud.sensors.dataprep import DataprepJobGroupIsFinishedSensor +from airflow.settings import Session from airflow.utils.trigger_rule import TriggerRule ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") DAG_ID = "example_dataprep" +CONNECTION_ID = f"connection_{DAG_ID}_{ENV_ID}".replace("-", "_") +DATAPREP_TOKEN = os.environ.get("SYSTEM_TESTS_DATAPREP_TOKEN", "") GCP_PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") GCS_BUCKET_NAME = f"dataprep-bucket-{DAG_ID}-{ENV_ID}" GCS_BUCKET_PATH = f"gs://{GCS_BUCKET_NAME}/task_results/" -FLOW_ID = os.environ.get("FLOW_ID") -RECIPE_ID = os.environ.get("RECIPE_ID") -RECIPE_NAME = os.environ.get("RECIPE_NAME") -WRITE_SETTINGS = ( - { - "writesettings": [ - { - "path": GCS_BUCKET_PATH, - "action": "create", - "format": "csv", - } - ], - }, -) +DATASET_URI = "gs://airflow-system-tests-resources/dataprep/dataset-00000.parquet" +DATASET_NAME = f"dataset_{DAG_ID}_{ENV_ID}".replace("-", "_") +DATASET_WRANGLED_NAME = f"wrangled_{DATASET_NAME}" +DATASET_WRANGLED_ID = "{{ task_instance.xcom_pull('create_wrangled_dataset')['id'] }}" + +FLOW_ID = "{{ task_instance.xcom_pull('create_flow')['id'] }}" +FLOW_COPY_ID = "{{ task_instance.xcom_pull('copy_flow')['id'] }}" +RECIPE_NAME = DATASET_WRANGLED_NAME +WRITE_SETTINGS = { + "writesettings": [ + { + "path": GCS_BUCKET_PATH + f"adhoc_{RECIPE_NAME}.csv", + "action": "create", + "format": "csv", + }, + ], +} + +log = logging.getLogger(__name__) -with DAG( + +with models.DAG( DAG_ID, schedule="@once", start_date=datetime(2021, 1, 1), # Override to match your needs @@ -71,42 +92,128 @@ project_id=GCP_PROJECT_ID, ) + @task + def create_connection(**kwargs) -> None: + connection = Connection( + conn_id=CONNECTION_ID, + description="Example Dataprep connection", + conn_type="dataprep", + extra={"token": DATAPREP_TOKEN}, + ) + session: Session = Session() + if session.query(Connection).filter(Connection.conn_id == CONNECTION_ID).first(): + log.warning("Connection %s already exists", CONNECTION_ID) + return None + session.add(connection) + session.commit() + + create_connection_task = create_connection() + + @task + def create_imported_dataset(): + hook = GoogleDataprepHook(dataprep_conn_id=CONNECTION_ID) + response = hook.create_imported_dataset( + body_request={ + "uri": DATASET_URI, + "name": DATASET_NAME, + } + ) + return response + + create_imported_dataset_task = create_imported_dataset() + + @task + def create_flow(): + hook = GoogleDataprepHook(dataprep_conn_id=CONNECTION_ID) + response = hook.create_flow( + body_request={ + "name": f"test_flow_{DAG_ID}_{ENV_ID}", + "description": "Test flow", + } + ) + return response + + create_flow_task = create_flow() + + @task + def create_wrangled_dataset(flow, imported_dataset): + hook = GoogleDataprepHook(dataprep_conn_id=CONNECTION_ID) + response = hook.create_wrangled_dataset( + body_request={ + "importedDataset": {"id": imported_dataset["id"]}, + "flow": {"id": flow["id"]}, + "name": DATASET_WRANGLED_NAME, + } + ) + return response + + create_wrangled_dataset_task = create_wrangled_dataset(create_flow_task, create_imported_dataset_task) + + @task + def create_output(wrangled_dataset): + hook = GoogleDataprepHook(dataprep_conn_id=CONNECTION_ID) + response = hook.create_output_object( + body_request={ + "execution": "dataflow", + "profiler": False, + "flowNodeId": wrangled_dataset["id"], + } + ) + return response + + create_output_task = create_output(create_wrangled_dataset_task) + + @task + def create_write_settings(output): + hook = GoogleDataprepHook(dataprep_conn_id=CONNECTION_ID) + response = hook.create_write_settings( + body_request={ + "path": GCS_BUCKET_PATH + f"adhoc_{RECIPE_NAME}.csv", + "action": "create", + "format": "csv", + "outputObjectId": output["id"], + } + ) + return response + + create_write_settings_task = create_write_settings(create_output_task) + + # [START how_to_dataprep_copy_flow_operator] + copy_task = DataprepCopyFlowOperator( + task_id="copy_flow", + dataprep_conn_id=CONNECTION_ID, + project_id=GCP_PROJECT_ID, + flow_id=FLOW_ID, + name=f"copy_{DATASET_NAME}", + ) + # [END how_to_dataprep_copy_flow_operator] + # [START how_to_dataprep_run_job_group_operator] run_job_group_task = DataprepRunJobGroupOperator( task_id="run_job_group", + dataprep_conn_id=CONNECTION_ID, project_id=GCP_PROJECT_ID, body_request={ - "wrangledDataset": {"id": RECIPE_ID}, + "wrangledDataset": {"id": DATASET_WRANGLED_ID}, "overrides": WRITE_SETTINGS, }, ) # [END how_to_dataprep_run_job_group_operator] - # [START how_to_dataprep_copy_flow_operator] - copy_task = DataprepCopyFlowOperator( - task_id="copy_flow", - project_id=GCP_PROJECT_ID, - flow_id=FLOW_ID, - name=f"dataprep_example_flow_{DAG_ID}_{ENV_ID}", - ) - # [END how_to_dataprep_copy_flow_operator] - # [START how_to_dataprep_dataprep_run_flow_operator] run_flow_task = DataprepRunFlowOperator( task_id="run_flow", + dataprep_conn_id=CONNECTION_ID, project_id=GCP_PROJECT_ID, - flow_id="{{ task_instance.xcom_pull('copy_flow')['id'] }}", - body_request={ - "overrides": { - RECIPE_NAME: WRITE_SETTINGS, - }, - }, + flow_id=FLOW_COPY_ID, + body_request={}, ) # [END how_to_dataprep_dataprep_run_flow_operator] # [START how_to_dataprep_get_job_group_operator] get_job_group_task = DataprepGetJobGroupOperator( task_id="get_job_group", + dataprep_conn_id=CONNECTION_ID, project_id=GCP_PROJECT_ID, job_group_id="{{ task_instance.xcom_pull('run_flow')['data'][0]['id'] }}", embed="", @@ -117,6 +224,7 @@ # [START how_to_dataprep_get_jobs_for_job_group_operator] get_jobs_for_job_group_task = DataprepGetJobsForJobGroupOperator( task_id="get_jobs_for_job_group", + dataprep_conn_id=CONNECTION_ID, job_group_id="{{ task_instance.xcom_pull('run_flow')['data'][0]['id'] }}", ) # [END how_to_dataprep_get_jobs_for_job_group_operator] @@ -124,6 +232,7 @@ # [START how_to_dataprep_job_group_finished_sensor] check_flow_status_sensor = DataprepJobGroupIsFinishedSensor( task_id="check_flow_status", + dataprep_conn_id=CONNECTION_ID, job_group_id="{{ task_instance.xcom_pull('run_flow')['data'][0]['id'] }}", ) # [END how_to_dataprep_job_group_finished_sensor] @@ -131,6 +240,7 @@ # [START how_to_dataprep_job_group_finished_sensor] check_job_group_status_sensor = DataprepJobGroupIsFinishedSensor( task_id="check_job_group_status", + dataprep_conn_id=CONNECTION_ID, job_group_id="{{ task_instance.xcom_pull('run_job_group')['id'] }}", ) # [END how_to_dataprep_job_group_finished_sensor] @@ -138,29 +248,55 @@ # [START how_to_dataprep_delete_flow_operator] delete_flow_task = DataprepDeleteFlowOperator( task_id="delete_flow", + dataprep_conn_id=CONNECTION_ID, flow_id="{{ task_instance.xcom_pull('copy_flow')['id'] }}", ) # [END how_to_dataprep_delete_flow_operator] delete_flow_task.trigger_rule = TriggerRule.ALL_DONE + delete_flow_task_original = DataprepDeleteFlowOperator( + task_id="delete_flow_original", + dataprep_conn_id=CONNECTION_ID, + flow_id="{{ task_instance.xcom_pull('create_flow')['id'] }}", + trigger_rule=TriggerRule.ALL_DONE, + ) + + @task(trigger_rule=TriggerRule.ALL_DONE) + def delete_dataset(dataset): + hook = GoogleDataprepHook(dataprep_conn_id=CONNECTION_ID) + hook.delete_imported_dataset(dataset_id=dataset["id"]) + + delete_dataset_task = delete_dataset(create_imported_dataset_task) + delete_bucket_task = GCSDeleteBucketOperator( task_id="delete_bucket", bucket_name=GCS_BUCKET_NAME, trigger_rule=TriggerRule.ALL_DONE, ) - ( + delete_connection = BashOperator( + task_id="delete_connection", + bash_command=f"airflow connections delete {CONNECTION_ID}", + trigger_rule=TriggerRule.ALL_DONE, + ) + + chain( # TEST SETUP - create_bucket_task - >> copy_task + create_bucket_task, + create_connection_task, + [create_imported_dataset_task, create_flow_task], + create_wrangled_dataset_task, + create_output_task, + create_write_settings_task, # TEST BODY - >> [run_job_group_task, run_flow_task] - >> get_job_group_task - >> get_jobs_for_job_group_task + copy_task, + [run_job_group_task, run_flow_task], + [get_job_group_task, get_jobs_for_job_group_task], + [check_flow_status_sensor, check_job_group_status_sensor], # TEST TEARDOWN - >> check_flow_status_sensor - >> [delete_flow_task, check_job_group_status_sensor] - >> delete_bucket_task + delete_dataset_task, + [delete_flow_task, delete_flow_task_original], + [delete_bucket_task, delete_connection], ) from tests.system.utils.watcher import watcher