Skip to content

Commit

Permalink
Fix GCSObjectExistenceSensor operator to return the same XCOM value i…
Browse files Browse the repository at this point in the history
…n deferrable and non-deferrable mode (#39206)
  • Loading branch information
VladaZakharova authored Apr 25, 2024
1 parent 8c556da commit 09f3446
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 12 deletions.
26 changes: 14 additions & 12 deletions airflow/providers/google/cloud/sensors/gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __init__(
self.object = object
self.use_glob = use_glob
self.google_cloud_conn_id = google_cloud_conn_id
self._matches: list[str] = []
self._matches: bool = False
self.impersonation_chain = impersonation_chain
self.retry = retry

Expand All @@ -101,17 +101,16 @@ def poke(self, context: Context) -> bool:
gcp_conn_id=self.google_cloud_conn_id,
impersonation_chain=self.impersonation_chain,
)
if self.use_glob:
self._matches = hook.list(self.bucket, match_glob=self.object)
return bool(self._matches)
else:
return hook.exists(self.bucket, self.object, self.retry)
self._matches = (
bool(hook.list(self.bucket, match_glob=self.object))
if self.use_glob
else hook.exists(self.bucket, self.object, self.retry)
)
return self._matches

def execute(self, context: Context) -> None:
def execute(self, context: Context):
"""Airflow runs this method on the worker and defers using the trigger."""
if not self.deferrable:
super().execute(context)
else:
if self.deferrable:
if not self.poke(context=context):
self.defer(
timeout=timedelta(seconds=self.timeout),
Expand All @@ -127,8 +126,11 @@ def execute(self, context: Context) -> None:
),
method_name="execute_complete",
)
else:
super().execute(context)
return self._matches

def execute_complete(self, context: Context, event: dict[str, str]) -> str:
def execute_complete(self, context: Context, event: dict[str, str]) -> bool:
"""
Act as a callback for when the trigger fires - returns immediately.
Expand All @@ -140,7 +142,7 @@ def execute_complete(self, context: Context, event: dict[str, str]) -> str:
raise AirflowSkipException(event["message"])
raise AirflowException(event["message"])
self.log.info("File %s was found in bucket %s.", self.object, self.bucket)
return event["message"]
return True


@deprecated(
Expand Down
30 changes: 30 additions & 0 deletions tests/providers/google/cloud/sensors/test_gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,20 @@ def next_time_side_effect():


class TestGoogleCloudStorageObjectSensor:
@mock.patch("airflow.providers.google.cloud.sensors.gcs.GCSHook")
@mock.patch("airflow.providers.google.cloud.sensors.gcs.GCSObjectExistenceSensor.defer")
def test_gcs_object_existence_sensor_return_value(self, mock_defer, mock_hook):
task = GCSObjectExistenceSensor(
task_id="task-id",
bucket=TEST_BUCKET,
object=TEST_OBJECT,
google_cloud_conn_id=TEST_GCP_CONN_ID,
deferrable=True,
)
mock_hook.return_value.list.return_value = True
return_value = task.execute(mock.MagicMock())
assert return_value, True

@mock.patch("airflow.providers.google.cloud.sensors.gcs.GCSHook")
def test_should_pass_argument_to_hook(self, mock_hook):
task = GCSObjectExistenceSensor(
Expand Down Expand Up @@ -183,6 +197,22 @@ def test_gcs_object_existence_sensor_execute_complete(self):
task.execute_complete(context=None, event={"status": "success", "message": "Job completed"})
mock_log_info.assert_called_with("File %s was found in bucket %s.", TEST_OBJECT, TEST_BUCKET)

def test_gcs_object_existence_sensor_execute_complete_return_value(self):
"""Asserts that logging occurs as expected when deferrable is set to True"""
task = GCSObjectExistenceSensor(
task_id="task-id",
bucket=TEST_BUCKET,
object=TEST_OBJECT,
google_cloud_conn_id=TEST_GCP_CONN_ID,
deferrable=True,
)
with mock.patch.object(task.log, "info") as mock_log_info:
return_value = task.execute_complete(
context=None, event={"status": "success", "message": "Job completed"}
)
mock_log_info.assert_called_with("File %s was found in bucket %s.", TEST_OBJECT, TEST_BUCKET)
assert return_value, True


class TestGoogleCloudStorageObjectAsyncSensor:
depcrecation_message = (
Expand Down

0 comments on commit 09f3446

Please sign in to comment.