diff --git a/airflow/sensors/sql_sensor.py b/airflow/sensors/sql_sensor.py
index de46e6d3db9b8..9a1b11fc50b1a 100644
--- a/airflow/sensors/sql_sensor.py
+++ b/airflow/sensors/sql_sensor.py
@@ -19,6 +19,7 @@
 
 from builtins import str
 
+from airflow.exceptions import AirflowException
 from airflow.hooks.base_hook import BaseHook
 from airflow.sensors.base_sensor_operator import BaseSensorOperator
 from airflow.utils.decorators import apply_defaults
@@ -33,22 +34,34 @@ class SqlSensor(BaseSensorOperator):
     :type conn_id: str
     :param sql: The sql to run. To pass, it needs to return at least one cell
         that contains a non-zero / empty string value.
+    :type sql: str
+    :param parameters: The parameters to render the SQL query with (optional).
+    :type parameters: mapping or iterable
     """
     template_fields = ('sql',)
     template_ext = ('.hql', '.sql',)
     ui_color = '#7c7287'
 
     @apply_defaults
-    def __init__(self, conn_id, sql, *args, **kwargs):
-        self.sql = sql
+    def __init__(self, conn_id, sql, parameters=None, *args, **kwargs):
         self.conn_id = conn_id
+        self.sql = sql
+        self.parameters = parameters
         super(SqlSensor, self).__init__(*args, **kwargs)
 
     def poke(self, context):
-        hook = BaseHook.get_connection(self.conn_id).get_hook()
+        conn = BaseHook.get_connection(self.conn_id)
+
+        allowed_conn_type = {'google_cloud_platform', 'jdbc', 'mssql',
+                             'mysql', 'oracle', 'postgres',
+                             'presto', 'sqlite', 'vertica'}
+        if conn.conn_type not in allowed_conn_type:
+            raise AirflowException("The connection type is not supported by SqlSensor. " +
+                                   "Supported connection types: {}".format(list(allowed_conn_type)))
+        hook = conn.get_hook()
 
-        self.log.info('Poking: %s', self.sql)
-        records = hook.get_records(self.sql)
+        self.log.info('Poking: %s (with parameters %s)', self.sql, self.parameters)
+        records = hook.get_records(self.sql, self.parameters)
         if not records:
             return False
         else:
diff --git a/tests/sensors/test_sql_sensor.py b/tests/sensors/test_sql_sensor.py
index 81bcdd1691919..0b25d58056eda 100644
--- a/tests/sensors/test_sql_sensor.py
+++ b/tests/sensors/test_sql_sensor.py
@@ -21,6 +21,7 @@
 
 from airflow import DAG
 from airflow import configuration
+from airflow.exceptions import AirflowException
 from airflow.sensors.sql_sensor import SqlSensor
 from airflow.utils.timezone import datetime
 
@@ -39,27 +40,56 @@ def setUp(self):
         }
         self.dag = DAG(TEST_DAG_ID, default_args=args)
 
+    def test_unsupported_conn_type(self):
+        t = SqlSensor(
+            task_id='sql_sensor_check',
+            conn_id='redis_default',
+            sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
+            dag=self.dag
+        )
+
+        with self.assertRaises(AirflowException):
+            t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
     @unittest.skipUnless(
         'mysql' in configuration.conf.get('core', 'sql_alchemy_conn'),
         "this is a mysql test")
     def test_sql_sensor_mysql(self):
-        t = SqlSensor(
+        t1 = SqlSensor(
             task_id='sql_sensor_check',
             conn_id='mysql_default',
             sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
             dag=self.dag
         )
-        t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+        t1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+        t2 = SqlSensor(
+            task_id='sql_sensor_check',
+            conn_id='mysql_default',
+            sql="SELECT count(%s) FROM INFORMATION_SCHEMA.TABLES",
+            parameters=["table_name"],
+            dag=self.dag
+        )
+        t2.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
 
     @unittest.skipUnless(
         'postgresql' in configuration.conf.get('core', 'sql_alchemy_conn'),
         "this is a postgres test")
     def test_sql_sensor_postgres(self):
-        t = SqlSensor(
+        t1 = SqlSensor(
             task_id='sql_sensor_check',
             conn_id='postgres_default',
             sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
             dag=self.dag
         )
-        t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+        t1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+        t2 = SqlSensor(
+            task_id='sql_sensor_check',
+            conn_id='postgres_default',
+            sql="SELECT count(%s) FROM INFORMATION_SCHEMA.TABLES",
+            parameters=["table_name"],
+            dag=self.dag
+        )
+        t2.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
 
     @mock.patch('airflow.sensors.sql_sensor.BaseHook')
     def test_sql_sensor_postgres_poke(self, mock_hook):
@@ -69,6 +99,7 @@ def test_sql_sensor_postgres_poke(self, mock_hook):
             sql="SELECT 1",
         )
 
+        mock_hook.get_connection('postgres_default').conn_type = "postgres"
         mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records
         mock_get_records.return_value = []