Skip to content

Commit

Permalink
fix(providers/sftp): respect soft_fail argument when exception is raised (#34169)
Browse files Browse the repository at this point in the history
  • Loading branch information
Lee-W authored Sep 7, 2023
1 parent 0ecbbac commit f5857a9
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
8 changes: 8 additions & 0 deletions airflow/providers/sftp/sensors/sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from paramiko.sftp import SFTP_NO_SUCH_FILE

from airflow.exceptions import AirflowSkipException
from airflow.providers.sftp.hooks.sftp import SFTPHook
from airflow.sensors.base import BaseSensorOperator, PokeReturnValue
from airflow.utils.timezone import convert_to_utc
Expand Down Expand Up @@ -84,24 +85,31 @@ def poke(self, context: Context) -> PokeReturnValue | bool:
return False
else:
actual_files_to_check = [self.path]

for actual_file_to_check in actual_files_to_check:
try:
mod_time = self.hook.get_mod_time(actual_file_to_check)
self.log.info("Found File %s last modified: %s", actual_file_to_check, mod_time)
except OSError as e:
if e.errno != SFTP_NO_SUCH_FILE:
# TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
if self.soft_fail:
raise AirflowSkipException from e
raise e
continue

if self.newer_than:
_mod_time = convert_to_utc(datetime.strptime(mod_time, "%Y%m%d%H%M%S"))
_newer_than = convert_to_utc(self.newer_than)
if _newer_than <= _mod_time:
files_found.append(actual_file_to_check)
else:
files_found.append(actual_file_to_check)

self.hook.close_conn()
if not len(files_found):
return False

if self.python_callable is not None:
if self.op_kwargs:
self.op_kwargs["files_found"] = files_found
Expand Down
12 changes: 9 additions & 3 deletions tests/providers/sftp/sensors/test_sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from paramiko.sftp import SFTP_FAILURE, SFTP_NO_SUCH_FILE
from pendulum import datetime as pendulum_datetime, timezone

from airflow.exceptions import AirflowSkipException
from airflow.providers.sftp.sensors.sftp import SFTPSensor
from airflow.sensors.base import PokeReturnValue

Expand All @@ -48,12 +49,17 @@ def test_file_absent(self, sftp_hook_mock):
sftp_hook_mock.return_value.get_mod_time.assert_called_once_with("/path/to/file/1970-01-01.txt")
assert not output

@pytest.mark.parametrize(
"soft_fail, expected_exception", ((False, OSError), (True, AirflowSkipException))
)
@patch("airflow.providers.sftp.sensors.sftp.SFTPHook")
def test_sftp_failure(self, sftp_hook_mock):
def test_sftp_failure(self, sftp_hook_mock, soft_fail: bool, expected_exception):
sftp_hook_mock.return_value.get_mod_time.side_effect = OSError(SFTP_FAILURE, "SFTP failure")
sftp_sensor = SFTPSensor(task_id="unit_test", path="/path/to/file/1970-01-01.txt")
sftp_sensor = SFTPSensor(
task_id="unit_test", path="/path/to/file/1970-01-01.txt", soft_fail=soft_fail
)
context = {"ds": "1970-01-01"}
with pytest.raises(OSError):
with pytest.raises(expected_exception):
sftp_sensor.poke(context)
sftp_hook_mock.return_value.get_mod_time.assert_called_once_with("/path/to/file/1970-01-01.txt")

Expand Down

0 comments on commit f5857a9

Please sign in to comment.