diff --git a/airflow/providers/google/cloud/example_dags/example_dataflow.py b/airflow/providers/google/cloud/example_dags/example_dataflow.py
index 0e6fe1f850..8633b3dc1f 100644
--- a/airflow/providers/google/cloud/example_dags/example_dataflow.py
+++ b/airflow/providers/google/cloud/example_dags/example_dataflow.py
@@ -20,12 +20,14 @@ Example Airflow DAG for Google Cloud Dataflow service
 """
 import os
+from urllib.parse import urlparse
 
 from airflow import models
 from airflow.providers.google.cloud.operators.dataflow import (
     CheckJobRunning, DataflowCreateJavaJobOperator, DataflowCreatePythonJobOperator,
     DataflowTemplatedJobStartOperator,
 )
+from airflow.providers.google.cloud.operators.gcs import GCSToLocalOperator
 from airflow.utils.dates import days_ago
 
 GCP_PROJECT_ID = os.environ.get('GCP_PROJECT_ID', 'example-project')
@@ -33,6 +35,11 @@
 GCS_STAGING = os.environ.get('GCP_DATAFLOW_GCS_STAGING', 'gs://test-dataflow-example/staging/')
 GCS_OUTPUT = os.environ.get('GCP_DATAFLOW_GCS_OUTPUT', 'gs://test-dataflow-example/output')
 GCS_JAR = os.environ.get('GCP_DATAFLOW_JAR', 'gs://test-dataflow-example/word-count-beam-bundled-0.1.jar')
+GCS_PYTHON = os.environ.get('GCP_DATAFLOW_PYTHON', 'gs://test-dataflow-example/wordcount_debugging.py')
+
+GCS_JAR_PARTS = urlparse(GCS_JAR)
+GCS_JAR_BUCKET_NAME = GCS_JAR_PARTS.netloc
+GCS_JAR_OBJECT_NAME = GCS_JAR_PARTS.path[1:]
 
 default_args = {
     "start_date": days_ago(1),
@@ -60,13 +67,49 @@
         },
         poll_sleep=10,
         job_class='org.apache.beam.examples.WordCount',
-        check_if_running=CheckJobRunning.WaitForRun,
+        check_if_running=CheckJobRunning.IgnoreJob,
     )
     # [END howto_operator_start_java_job]
 
+    jar_to_local = GCSToLocalOperator(
+        task_id="jar-to-local",
+        bucket=GCS_JAR_BUCKET_NAME,
+        object_name=GCS_JAR_OBJECT_NAME,
+        filename="/tmp/dataflow-{{ ds_nodash }}.jar",
+    )
+
+    start_java_job_local = DataflowCreateJavaJobOperator(
+        task_id="start-java-job-local",
+        jar="/tmp/dataflow-{{ ds_nodash }}.jar",
+        job_name='{{task.task_id}}',
+        options={
+            'output': GCS_OUTPUT,
+        },
+        poll_sleep=10,
+        job_class='org.apache.beam.examples.WordCount',
+        check_if_running=CheckJobRunning.WaitForRun,
+    )
+    jar_to_local >> start_java_job_local
+
     # [START howto_operator_start_python_job]
     start_python_job = DataflowCreatePythonJobOperator(
         task_id="start-python-job",
+        py_file=GCS_PYTHON,
+        py_options=[],
+        job_name='{{task.task_id}}',
+        options={
+            'output': GCS_OUTPUT,
+        },
+        py_requirements=[
+            'apache-beam[gcp]>=2.14.0'
+        ],
+        py_interpreter='python3',
+        py_system_site_packages=False
+    )
+    # [END howto_operator_start_python_job]
+
+    start_python_job_local = DataflowCreatePythonJobOperator(
+        task_id="start-python-job-local",
         py_file='apache_beam.examples.wordcount',
         py_options=['-m'],
         job_name='{{task.task_id}}',
@@ -79,7 +122,6 @@
         py_interpreter='python3',
         py_system_site_packages=False
     )
-    # [END howto_operator_start_python_job]
 
     start_template_job = DataflowTemplatedJobStartOperator(
         task_id="start-template-job",
diff --git a/airflow/providers/google/cloud/hooks/gcs.py b/airflow/providers/google/cloud/hooks/gcs.py
index 7dfdb87f6f..5c8e4a1492 100644
--- a/airflow/providers/google/cloud/hooks/gcs.py
+++ b/airflow/providers/google/cloud/hooks/gcs.py
@@ -19,13 +19,16 @@
 """
 This module contains a Google Cloud Storage hook.
""" +import functools import gzip as gz import os import shutil import warnings +from contextlib import contextmanager from io import BytesIO from os import path -from typing import Optional, Set, Tuple, Union +from tempfile import NamedTemporaryFile +from typing import Optional, Set, Tuple, TypeVar, Union from urllib.parse import urlparse from google.api_core.exceptions import NotFound @@ -35,6 +38,70 @@ from airflow.providers.google.common.hooks.base_google import GoogleBaseHook from airflow.version import version +RT = TypeVar('RT') # pylint: disable=invalid-name + + +def _fallback_object_url_to_object_name_and_bucket_name( + object_url_keyword_arg_name='object_url', + bucket_name_keyword_arg_name='bucket_name', + object_name_keyword_arg_name='object_name', +): + """ + Decorator factory that convert object URL parameter to object name and bucket name parameter. + + :param object_url_keyword_arg_name: Name of the object URL parameter + :type object_url_keyword_arg_name: str + :param bucket_name_keyword_arg_name: Name of the bucket name parameter + :type bucket_name_keyword_arg_name: str + :param object_name_keyword_arg_name: Name of the object name parameter + :type object_name_keyword_arg_name: str + :return: Decorator + """ + def _wrapper(func): + + @functools.wraps(func) + def _inner_wrapper(self: "GCSHook", * args, **kwargs) -> RT: + if args: + raise AirflowException( + "You must use keyword arguments in this methods rather than positional") + + object_url = kwargs.get(object_url_keyword_arg_name) + bucket_name = kwargs.get(bucket_name_keyword_arg_name) + object_name = kwargs.get(object_name_keyword_arg_name) + + if object_url and bucket_name and object_name: + raise AirflowException( + "The mutually exclusive parameters. `object_url`, `bucket_name` together " + "with `object_name` parameters are present. " + "Please provide `object_url` or `bucket_name` and `object_name`." + ) + if object_url: + bucket_name, object_name = _parse_gcs_url(object_url) + kwargs[bucket_name_keyword_arg_name] = bucket_name + kwargs[object_name_keyword_arg_name] = object_name + del kwargs[object_url_keyword_arg_name] + + if not object_name or not bucket_name: + raise TypeError( + f"{func.__name__}() missing 2 required positional arguments: " + f"'{bucket_name_keyword_arg_name}' and '{object_name_keyword_arg_name}' " + f"or {object_url_keyword_arg_name}" + ) + if not object_name: + raise TypeError( + f"{func.__name__}() missing 1 required positional argument: " + f"'{object_name_keyword_arg_name}'" + ) + if not bucket_name: + raise TypeError( + f"{func.__name__}() missing 1 required positional argument: " + f"'{bucket_name_keyword_arg_name}'" + ) + + return func(self, *args, **kwargs) + return _inner_wrapper + return _wrapper + class GCSHook(GoogleBaseHook): """ @@ -200,6 +267,36 @@ def download(self, bucket_name, object_name, filename=None): else: return blob.download_as_string() + @_fallback_object_url_to_object_name_and_bucket_name() + @contextmanager + def provide_file( + self, + bucket_name: Optional[str] = None, + object_name: Optional[str] = None, + object_url: Optional[str] = None + ): + """ + Downloads the file to a temporary directory and returns a file handle + + You can use this method by passing the bucket_name and object_name parameters + or just object_url parameter. + + :param bucket_name: The bucket to fetch from. + :type bucket_name: str + :param object_name: The object to fetch. + :type object_name: str + :param object_url: File reference url. 
+        :type object_url: str
+        :return: File handle
+        """
+        if object_name is None:
+            raise ValueError("Object name can not be empty")
+        _, _, file_name = object_name.rpartition("/")
+        with NamedTemporaryFile(suffix=file_name) as tmp_file:
+            self.download(bucket_name=bucket_name, object_name=object_name, filename=tmp_file.name)
+            tmp_file.flush()
+            yield tmp_file
+
     def upload(self, bucket_name: str, object_name: str, filename: Optional[str] = None,
                data: Optional[Union[str, bytes]] = None, mime_type: Optional[str] = None,
                gzip: bool = False, encoding: str = 'utf-8') -> None:
@@ -877,7 +974,7 @@ def _prepare_sync_plan(
     return to_copy_blobs, to_delete_blobs, to_rewrite_blobs
 
 
-def _parse_gcs_url(gsurl):
+def _parse_gcs_url(gsurl: str) -> Tuple[str, str]:
     """
     Given a Google Cloud Storage URL (gs://<bucket>/<blob>), returns a
     tuple containing the corresponding bucket and blob.
@@ -886,8 +983,10 @@
     parsed_url = urlparse(gsurl)
     if not parsed_url.netloc:
         raise AirflowException('Please provide a bucket name')
-    else:
-        bucket = parsed_url.netloc
-        # Remove leading '/' but NOT trailing one
-        blob = parsed_url.path.lstrip('/')
-        return bucket, blob
+    if parsed_url.scheme.lower() != "gs":
+        raise AirflowException(f"Schema must be 'gs://'. Current schema: '{parsed_url.scheme}://'")
+
+    bucket = parsed_url.netloc
+    # Remove leading '/' but NOT trailing one
+    blob = parsed_url.path.lstrip('/')
+    return bucket, blob
diff --git a/airflow/providers/google/cloud/operators/dataflow.py b/airflow/providers/google/cloud/operators/dataflow.py
index ed8a603e1e..171a52dc46 100644
--- a/airflow/providers/google/cloud/operators/dataflow.py
+++ b/airflow/providers/google/cloud/operators/dataflow.py
@@ -20,10 +20,8 @@
 """
 import copy
-import os
 import re
-import tempfile
-import uuid
+from contextlib import ExitStack
 from enum import Enum
 from typing import List, Optional
 
@@ -221,23 +219,27 @@ def execute(self, context):
             name=self.job_name, variables=dataflow_options, project_id=self.project_id)
 
         if not is_running:
-            bucket_helper = GoogleCloudBucketHelper(
-                self.gcp_conn_id, self.delegate_to)
-            self.jar = bucket_helper.google_cloud_to_local(self.jar)
-
-            def set_current_job_id(job_id):
-                self.job_id = job_id
-
-            self.hook.start_java_dataflow(
-                job_name=self.job_name,
-                variables=dataflow_options,
-                jar=self.jar,
-                job_class=self.job_class,
-                append_job_name=True,
-                multiple_jobs=self.multiple_jobs,
-                on_new_job_id_callback=set_current_job_id,
-                project_id=self.project_id,
-            )
+            with ExitStack() as exit_stack:
+                if self.jar.lower().startswith('gs://'):
+                    gcs_hook = GCSHook(self.gcp_conn_id, self.delegate_to)
+                    tmp_gcs_file = exit_stack.enter_context(  # pylint: disable=no-member
+                        gcs_hook.provide_file(object_url=self.jar)
+                    )
+                    self.jar = tmp_gcs_file.name
+
+                def set_current_job_id(job_id):
+                    self.job_id = job_id
+
+                self.hook.start_java_dataflow(
+                    job_name=self.job_name,
+                    variables=dataflow_options,
+                    jar=self.jar,
+                    job_class=self.job_class,
+                    append_job_name=True,
+                    multiple_jobs=self.multiple_jobs,
+                    on_new_job_id_callback=set_current_job_id,
+                    project_id=self.project_id,
+                )
 
     def on_kill(self) -> None:
         self.log.info("On kill.")
@@ -471,85 +473,42 @@ def __init__(  # pylint: disable=too-many-arguments
 
     def execute(self, context):
         """Execute the python dataflow job."""
-        bucket_helper = GoogleCloudBucketHelper(
-            self.gcp_conn_id, self.delegate_to)
-        self.py_file = bucket_helper.google_cloud_to_local(self.py_file)
-        self.hook = DataflowHook(
-            gcp_conn_id=self.gcp_conn_id,
-            delegate_to=self.delegate_to,
-            poll_sleep=self.poll_sleep
-        )
-        dataflow_options = self.dataflow_default_options.copy()
-        dataflow_options.update(self.options)
-        # Convert argument names from lowerCamelCase to snake case.
-        camel_to_snake = lambda name: re.sub(
-            r'[A-Z]', lambda x: '_' + x.group(0).lower(), name)
-        formatted_options = {camel_to_snake(key): dataflow_options[key]
-                             for key in dataflow_options}
+        with ExitStack() as exit_stack:
+            if self.py_file.lower().startswith('gs://'):
+                gcs_hook = GCSHook(self.gcp_conn_id, self.delegate_to)
+                tmp_gcs_file = exit_stack.enter_context(  # pylint: disable=no-member
+                    gcs_hook.provide_file(object_url=self.py_file)
+                )
+                self.py_file = tmp_gcs_file.name
+
+            self.hook = DataflowHook(
+                gcp_conn_id=self.gcp_conn_id,
+                delegate_to=self.delegate_to,
+                poll_sleep=self.poll_sleep
+            )
+            dataflow_options = self.dataflow_default_options.copy()
+            dataflow_options.update(self.options)
+            # Convert argument names from lowerCamelCase to snake case.
+            camel_to_snake = lambda name: re.sub(r'[A-Z]', lambda x: '_' + x.group(0).lower(), name)
+            formatted_options = {camel_to_snake(key): dataflow_options[key]
+                                 for key in dataflow_options}
 
-        def set_current_job_id(job_id):
-            self.job_id = job_id
+            def set_current_job_id(job_id):
+                self.job_id = job_id
 
-        self.hook.start_python_dataflow(
-            job_name=self.job_name,
-            variables=formatted_options,
-            dataflow=self.py_file,
-            py_options=self.py_options,
-            py_interpreter=self.py_interpreter,
-            py_requirements=self.py_requirements,
-            py_system_site_packages=self.py_system_site_packages,
-            on_new_job_id_callback=set_current_job_id,
-            project_id=self.project_id,
-        )
+            self.hook.start_python_dataflow(
+                job_name=self.job_name,
+                variables=formatted_options,
+                dataflow=self.py_file,
+                py_options=self.py_options,
+                py_interpreter=self.py_interpreter,
+                py_requirements=self.py_requirements,
+                py_system_site_packages=self.py_system_site_packages,
+                on_new_job_id_callback=set_current_job_id,
+                project_id=self.project_id,
+            )
 
     def on_kill(self) -> None:
         self.log.info("On kill.")
         if self.job_id:
             self.hook.cancel_job(job_id=self.job_id, project_id=self.project_id)
-
-
-class GoogleCloudBucketHelper:
-    """GoogleCloudStorageHook helper class to download GCS object."""
-    GCS_PREFIX_LENGTH = 5
-
-    def __init__(self,
-                 gcp_conn_id: str = 'google_cloud_default',
-                 delegate_to: Optional[str] = None) -> None:
-        self._gcs_hook = GCSHook(gcp_conn_id, delegate_to)
-
-    def google_cloud_to_local(self, file_name: str) -> str:
-        """
-        Checks whether the file specified by file_name is stored in Google Cloud
-        Storage (GCS), if so, downloads the file and saves it locally. The full
-        path of the saved file will be returned. Otherwise the local file_name
-        will be returned immediately.
-
-        :param file_name: The full path of input file.
-        :type file_name: str
-        :return: The full path of local file.
-        :rtype: str
-        """
-        if not file_name.startswith('gs://'):
-            return file_name
-
-        # Extracts bucket_id and object_id by first removing 'gs://' prefix and
-        # then split the remaining by path delimiter '/'.
-        path_components = file_name[self.GCS_PREFIX_LENGTH:].split('/')
-        if len(path_components) < 2:
-            raise Exception(
-                'Invalid Google Cloud Storage (GCS) object path: {}'
-                .format(file_name))
-
-        bucket_id = path_components[0]
-        object_id = '/'.join(path_components[1:])
-        local_file = os.path.join(
-            tempfile.gettempdir(),
-            'dataflow{}-{}'.format(str(uuid.uuid4())[:8], path_components[-1])
-        )
-        self._gcs_hook.download(bucket_id, object_id, local_file)
-
-        if os.stat(local_file).st_size > 0:
-            return local_file
-        raise Exception(
-            'Failed to download Google Cloud Storage (GCS) object: {}'
-            .format(file_name))
diff --git a/tests/providers/google/cloud/hooks/test_gcs.py b/tests/providers/google/cloud/hooks/test_gcs.py
index 5bd2c25bc0..8a84483952 100644
--- a/tests/providers/google/cloud/hooks/test_gcs.py
+++ b/tests/providers/google/cloud/hooks/test_gcs.py
@@ -19,6 +19,7 @@
 import copy
 import io
 import os
+import re
 import tempfile
 import unittest
 from datetime import datetime, timedelta
@@ -29,6 +30,7 @@
 from airflow.exceptions import AirflowException
 from airflow.providers.google.cloud.hooks import gcs
+from airflow.providers.google.cloud.hooks.gcs import _fallback_object_url_to_object_name_and_bucket_name
 from airflow.utils import timezone
 from airflow.version import version
 from tests.providers.google.cloud.utils.base_gcp_mock import mock_base_gcp_hook_default_project_id
@@ -53,6 +55,8 @@ def test_parse_gcs_url(self):
         # invalid URI
         self.assertRaises(AirflowException, gcs._parse_gcs_url,
                           'gs:/bucket/path/to/blob')
+        self.assertRaises(AirflowException, gcs._parse_gcs_url,
+                          'http://google.com/aaa')
 
         # trailing slash
         self.assertEqual(
@@ -64,6 +68,54 @@ def test_parse_gcs_url(self):
             gcs._parse_gcs_url('gs://bucket/'),
             ('bucket', ''))
 
 
+class TestFallbackObjectUrlToObjectNameAndBucketName(unittest.TestCase):
+    def setUp(self) -> None:
+        self.assertion_on_body = mock.MagicMock()
+
+        @_fallback_object_url_to_object_name_and_bucket_name()
+        def test_method(
+            _,
+            bucket_name=None,
+            object_name=None,
+            object_url=None
+        ):
+            assert object_name == "OBJECT_NAME"
+            assert bucket_name == "BUCKET_NAME"
+            assert object_url is None
+            self.assertion_on_body()
+
+        self.test_method = test_method
+
+    def test_should_url(self):
+        self.test_method(None, object_url="gs://BUCKET_NAME/OBJECT_NAME")
+        self.assertion_on_body.assert_called_once()
+
+    def test_should_support_bucket_and_object(self):
+        self.test_method(None, bucket_name="BUCKET_NAME", object_name="OBJECT_NAME")
+        self.assertion_on_body.assert_called_once()
+
+    def test_should_raise_exception_on_missing(self):
+        with self.assertRaisesRegex(
+            TypeError,
+            re.escape(
+                "test_method() missing 2 required positional arguments: 'bucket_name' and 'object_name'"
+            )):
+            self.test_method(None)
+        self.assertion_on_body.assert_not_called()
+
+    def test_should_raise_exception_on_mutually_exclusive(self):
+        with self.assertRaisesRegex(
+            AirflowException,
+            re.escape("The mutually exclusive parameters")
+        ):
+            self.test_method(
+                None,
+                bucket_name="BUCKET_NAME",
+                object_name="OBJECT_NAME",
+                object_url="gs://BUCKET_NAME/OBJECT_NAME"
+            )
+        self.assertion_on_body.assert_not_called()
+
+
 class TestGCSHook(unittest.TestCase):
     def setUp(self):
         with mock.patch(
@@ -658,6 +710,37 @@ def test_download_to_file(self, mock_service):
         self.assertEqual(response, test_file)
         download_filename_method.assert_called_once_with(test_file)
 
+    @mock.patch(GCS_STRING.format('NamedTemporaryFile'))
+    @mock.patch(GCS_STRING.format('GCSHook.get_conn'))
+    def test_provide_file(self, mock_service, mock_temp_file):
+        test_bucket = 'test_bucket'
+        test_object = 'test_object'
+        test_object_bytes = io.BytesIO(b"input")
+        test_file = 'test_file'
+
+        download_filename_method = mock_service.return_value.bucket.return_value \
+            .blob.return_value.download_to_filename
+        download_filename_method.return_value = None
+
+        download_as_a_string_method = mock_service.return_value.bucket.return_value \
+            .blob.return_value.download_as_string
+        download_as_a_string_method.return_value = test_object_bytes
+        mock_temp_file.return_value.__enter__.return_value = mock.MagicMock()
+        mock_temp_file.return_value.__enter__.return_value.name = test_file
+
+        with self.gcs_hook.provide_file(
+                bucket_name=test_bucket,
+                object_name=test_object) as response:
+
+            self.assertEqual(test_file, response.name)
+        download_filename_method.assert_called_once_with(test_file)
+        mock_temp_file.assert_has_calls([
+            mock.call(suffix='test_object'),
+            mock.call().__enter__(),
+            mock.call().__enter__().flush(),
+            mock.call().__exit__(None, None, None)
+        ])
+
 
 class TestGCSHookUpload(unittest.TestCase):
     def setUp(self):
diff --git a/tests/providers/google/cloud/operators/test_dataflow.py b/tests/providers/google/cloud/operators/test_dataflow.py
index 246d53d585..c76ed39734 100644
--- a/tests/providers/google/cloud/operators/test_dataflow.py
+++ b/tests/providers/google/cloud/operators/test_dataflow.py
@@ -23,7 +23,7 @@
 from airflow.providers.google.cloud.operators.dataflow import (
     CheckJobRunning, DataflowCreateJavaJobOperator, DataflowCreatePythonJobOperator,
-    DataflowTemplatedJobStartOperator, GoogleCloudBucketHelper,
+    DataflowTemplatedJobStartOperator,
 )
 from airflow.version import version
 
@@ -36,7 +36,7 @@
 }
 PY_FILE = 'gs://my-bucket/my-object.py'
 PY_INTERPRETER = 'python3'
-JAR_FILE = 'example/test.jar'
+JAR_FILE = 'gs://my-bucket/example/test.jar'
 JOB_CLASS = 'com.test.NotMain'
 PY_OPTIONS = ['-m']
 DEFAULT_OPTIONS_PYTHON = DEFAULT_OPTIONS_JAVA = {
@@ -88,14 +88,14 @@ def test_init(self):
                          EXPECTED_ADDITIONAL_OPTIONS)
 
     @mock.patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook')
-    @mock.patch(GCS_HOOK_STRING.format('GoogleCloudBucketHelper'))
+    @mock.patch('airflow.providers.google.cloud.operators.dataflow.GCSHook')
     def test_exec(self, gcs_hook, dataflow_mock):
         """Test DataflowHook is created and the right args are passed to
         start_python_workflow.
         """
         start_python_hook = dataflow_mock.return_value.start_python_dataflow
-        gcs_download_hook = gcs_hook.return_value.google_cloud_to_local
+        gcs_provide_file = gcs_hook.return_value.provide_file
         self.dataflow.execute(None)
         self.assertTrue(dataflow_mock.called)
         expected_options = {
@@ -104,7 +104,7 @@ def test_exec(self, gcs_hook, dataflow_mock):
             'output': 'gs://test/output',
             'labels': {'foo': 'bar', 'airflow-version': TEST_VERSION}
         }
-        gcs_download_hook.assert_called_once_with(PY_FILE)
+        gcs_provide_file.assert_called_once_with(object_url=PY_FILE)
         start_python_hook.assert_called_once_with(
             job_name=JOB_NAME,
             variables=expected_options,
@@ -145,18 +145,18 @@ def test_init(self):
         self.assertEqual(self.dataflow.check_if_running, CheckJobRunning.WaitForRun)
 
     @mock.patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook')
-    @mock.patch(GCS_HOOK_STRING.format('GoogleCloudBucketHelper'))
+    @mock.patch('airflow.providers.google.cloud.operators.dataflow.GCSHook')
     def test_exec(self, gcs_hook, dataflow_mock):
         """Test DataflowHook is created and the right args are passed to
         start_java_workflow.
""" start_java_hook = dataflow_mock.return_value.start_java_dataflow - gcs_download_hook = gcs_hook.return_value.google_cloud_to_local + gcs_provide_file = gcs_hook.return_value.provide_file self.dataflow.check_if_running = CheckJobRunning.IgnoreJob self.dataflow.execute(None) self.assertTrue(dataflow_mock.called) - gcs_download_hook.assert_called_once_with(JAR_FILE) + gcs_provide_file.assert_called_once_with(object_url=JAR_FILE) start_java_hook.assert_called_once_with( job_name=JOB_NAME, variables=mock.ANY, @@ -169,7 +169,7 @@ def test_exec(self, gcs_hook, dataflow_mock): ) @mock.patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook') - @mock.patch(GCS_HOOK_STRING.format('GoogleCloudBucketHelper')) + @mock.patch('airflow.providers.google.cloud.operators.dataflow.GCSHook') def test_check_job_running_exec(self, gcs_hook, dataflow_mock): """Test DataflowHook is created and the right args are passed to start_java_workflow. @@ -178,17 +178,17 @@ def test_check_job_running_exec(self, gcs_hook, dataflow_mock): dataflow_running = dataflow_mock.return_value.is_job_dataflow_running dataflow_running.return_value = True start_java_hook = dataflow_mock.return_value.start_java_dataflow - gcs_download_hook = gcs_hook.return_value.google_cloud_to_local + gcs_provide_file = gcs_hook.return_value.provide_file self.dataflow.check_if_running = True self.dataflow.execute(None) self.assertTrue(dataflow_mock.called) - gcs_download_hook.assert_not_called() + gcs_provide_file.assert_not_called() start_java_hook.assert_not_called() dataflow_running.assert_called_once_with( name=JOB_NAME, variables=mock.ANY, project_id=None) @mock.patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook') - @mock.patch(GCS_HOOK_STRING.format('GoogleCloudBucketHelper')) + @mock.patch('airflow.providers.google.cloud.operators.dataflow.GCSHook') def test_check_job_not_running_exec(self, gcs_hook, dataflow_mock): """Test DataflowHook is created and the right args are passed to start_java_workflow with option to check if job is running @@ -197,11 +197,11 @@ def test_check_job_not_running_exec(self, gcs_hook, dataflow_mock): dataflow_running = dataflow_mock.return_value.is_job_dataflow_running dataflow_running.return_value = False start_java_hook = dataflow_mock.return_value.start_java_dataflow - gcs_download_hook = gcs_hook.return_value.google_cloud_to_local + gcs_provide_file = gcs_hook.return_value.provide_file self.dataflow.check_if_running = True self.dataflow.execute(None) self.assertTrue(dataflow_mock.called) - gcs_download_hook.assert_called_once_with(JAR_FILE) + gcs_provide_file.assert_called_once_with(object_url=JAR_FILE) start_java_hook.assert_called_once_with( job_name=JOB_NAME, variables=mock.ANY, @@ -216,7 +216,7 @@ def test_check_job_not_running_exec(self, gcs_hook, dataflow_mock): name=JOB_NAME, variables=mock.ANY, project_id=None) @mock.patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook') - @mock.patch(GCS_HOOK_STRING.format('GoogleCloudBucketHelper')) + @mock.patch('airflow.providers.google.cloud.operators.dataflow.GCSHook') def test_check_multiple_job_exec(self, gcs_hook, dataflow_mock): """Test DataflowHook is created and the right args are passed to start_java_workflow with option to check multiple jobs @@ -225,12 +225,12 @@ def test_check_multiple_job_exec(self, gcs_hook, dataflow_mock): dataflow_running = dataflow_mock.return_value.is_job_dataflow_running dataflow_running.return_value = False start_java_hook = dataflow_mock.return_value.start_java_dataflow - 
-        gcs_download_hook = gcs_hook.return_value.google_cloud_to_local
+        gcs_provide_file = gcs_hook.return_value.provide_file
         self.dataflow.multiple_jobs = True
         self.dataflow.check_if_running = True
         self.dataflow.execute(None)
         self.assertTrue(dataflow_mock.called)
-        gcs_download_hook.assert_called_once_with(JAR_FILE)
+        gcs_provide_file.assert_called_once_with(object_url=JAR_FILE)
         start_java_hook.assert_called_once_with(
             job_name=JOB_NAME,
             variables=mock.ANY,
@@ -289,72 +289,3 @@ def test_exec(self, dataflow_mock):
             on_new_job_id_callback=mock.ANY,
             project_id=None,
         )
-
-
-class TestGoogleCloudBucketHelper(unittest.TestCase):
-
-    @mock.patch(
-        'airflow.providers.google.cloud.operators.dataflow.GoogleCloudBucketHelper.__init__'
-    )
-    def test_invalid_object_path(self, mock_parent_init):
-        # This is just the path of a bucket hence invalid filename
-        file_name = 'gs://test-bucket'
-        mock_parent_init.return_value = None
-
-        gcs_bucket_helper = GoogleCloudBucketHelper()
-        gcs_bucket_helper._gcs_hook = mock.Mock()
-
-        with self.assertRaises(Exception) as context:
-            gcs_bucket_helper.google_cloud_to_local(file_name)
-
-        self.assertEqual(
-            'Invalid Google Cloud Storage (GCS) object path: {}'.format(file_name),
-            str(context.exception))
-
-    @mock.patch(
-        'airflow.providers.google.cloud.operators.dataflow.GoogleCloudBucketHelper.__init__'
-    )
-    def test_valid_object(self, mock_parent_init):
-        file_name = 'gs://test-bucket/path/to/obj.jar'
-        mock_parent_init.return_value = None
-
-        gcs_bucket_helper = GoogleCloudBucketHelper()
-        gcs_bucket_helper._gcs_hook = mock.Mock()
-
-        # pylint: disable=redefined-builtin,unused-argument
-        def _mock_download(bucket, object, filename=None):
-            text_file_contents = 'text file contents'
-            with open(filename, 'w') as text_file:
-                text_file.write(text_file_contents)
-            return text_file_contents
-
-        gcs_bucket_helper._gcs_hook.download.side_effect = _mock_download
-
-        local_file = gcs_bucket_helper.google_cloud_to_local(file_name)
-        self.assertIn('obj.jar', local_file)
-
-    @mock.patch(
-        'airflow.providers.google.cloud.operators.dataflow.GoogleCloudBucketHelper.__init__'
-    )
-    def test_empty_object(self, mock_parent_init):
-        file_name = 'gs://test-bucket/path/to/obj.jar'
-        mock_parent_init.return_value = None
-
-        gcs_bucket_helper = GoogleCloudBucketHelper()
-        gcs_bucket_helper._gcs_hook = mock.Mock()
-
-        # pylint: disable=redefined-builtin,unused-argument
-        def _mock_download(bucket, object, filename=None):
-            text_file_contents = ''
-            with open(filename, 'w') as text_file:
-                text_file.write(text_file_contents)
-            return text_file_contents
-
-        gcs_bucket_helper._gcs_hook.download.side_effect = _mock_download
-
-        with self.assertRaises(Exception) as context:
-            gcs_bucket_helper.google_cloud_to_local(file_name)
-
-        self.assertEqual(
-            'Failed to download Google Cloud Storage (GCS) object: {}'.format(file_name),
-            str(context.exception))
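
Editorial note, not part of the patch: the sketch below shows how the new GCSHook.provide_file() context manager introduced in this diff is meant to be consumed outside of the Dataflow operators. It is a minimal illustration; the helper name, bucket, object path and connection id are assumptions, not values taken from the patch.

# Minimal usage sketch for GCSHook.provide_file(). Assumptions: a configured
# "google_cloud_default" connection and an existing object at the given URL.
from contextlib import ExitStack

from airflow.providers.google.cloud.hooks.gcs import GCSHook


def stage_gcs_object_locally(object_url: str) -> None:
    """Hypothetical helper: download a GCS object to a temp file and use its path."""
    hook = GCSHook(gcp_conn_id='google_cloud_default')
    with ExitStack() as exit_stack:
        # provide_file() accepts keyword arguments only: either object_url,
        # or bucket_name together with object_name.
        tmp_file = exit_stack.enter_context(hook.provide_file(object_url=object_url))
        # The temporary file is removed when the ExitStack closes, so anything
        # that reads tmp_file.name must run inside this block.
        print("staged to", tmp_file.name)


if __name__ == "__main__":
    stage_gcs_object_locally("gs://example-bucket/path/to/wordcount.jar")

This mirrors the ExitStack pattern the operators now use: the temporary file has to stay alive until the Dataflow job has been submitted, which is why the hook call is entered on an ExitStack rather than in a short with-block.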
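A second editorial sketch, also not part of the patch: the behaviour of the stricter _parse_gcs_url() helper after this change, with placeholder bucket and object names.

from airflow.providers.google.cloud.hooks.gcs import _parse_gcs_url

# Splits a gs:// URL into (bucket, blob).
bucket, blob = _parse_gcs_url("gs://example-bucket/path/to/file.txt")
assert bucket == "example-bucket"
assert blob == "path/to/file.txt"
# URLs with a non-"gs" scheme, e.g. "http://...", now raise AirflowException.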