Skip to content

Commit

Permalink
[AIRFLOW-3905] Allow using "parameters" in SqlSensor (#4723)
Browse files Browse the repository at this point in the history
* [AIRFLOW-3905] Allow 'parameters' in SqlSensor

* Add check on conn_type & add test

Not all SQL-related connection types are supported by SqlSensor,
due to limitations in the Connection model and in the individual hook implementations.
  • Loading branch information
XD-DENG authored and ashb committed Mar 7, 2019
1 parent eb9c402 commit 4f27739
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 9 deletions.
23 changes: 18 additions & 5 deletions airflow/sensors/sql_sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from builtins import str

from airflow.exceptions import AirflowException
from airflow.hooks.base_hook import BaseHook
from airflow.sensors.base_sensor_operator import BaseSensorOperator
from airflow.utils.decorators import apply_defaults
Expand All @@ -33,22 +34,34 @@ class SqlSensor(BaseSensorOperator):
:type conn_id: str
:param sql: The sql to run. To pass, it needs to return at least one cell
that contains a non-zero / empty string value.
:type sql: str
:param parameters: The parameters to render the SQL query with (optional).
:type parameters: mapping or iterable
"""
template_fields = ('sql',)
template_ext = ('.hql', '.sql',)
ui_color = '#7c7287'

@apply_defaults
def __init__(self, conn_id, sql, parameters=None, *args, **kwargs):
    """Set up the sensor.

    :param conn_id: The connection to run the sensor against.
    :type conn_id: str
    :param sql: The sql to run. To pass, it needs to return at least one
        cell that contains a non-zero / empty string value.
    :type sql: str
    :param parameters: The parameters to render the SQL query with
        (optional).
    :type parameters: mapping or iterable
    """
    # NOTE: the scraped diff interleaved the old and new signatures; this is
    # the post-image: `parameters` was added with a default of None so
    # existing positional callers (conn_id, sql) keep working unchanged.
    self.conn_id = conn_id
    self.sql = sql
    self.parameters = parameters
    super(SqlSensor, self).__init__(*args, **kwargs)

def poke(self, context):
hook = BaseHook.get_connection(self.conn_id).get_hook()
conn = BaseHook.get_connection(self.conn_id)

allowed_conn_type = {'google_cloud_platform', 'jdbc', 'mssql',
'mysql', 'oracle', 'postgres',
'presto', 'sqlite', 'vertica'}
if conn.conn_type not in allowed_conn_type:
raise AirflowException("The connection type is not supported by SqlSensor. " +
"Supported connection types: {}".format(list(allowed_conn_type)))
hook = conn.get_hook()

self.log.info('Poking: %s', self.sql)
records = hook.get_records(self.sql)
self.log.info('Poking: %s (with parameters %s)', self.sql, self.parameters)
records = hook.get_records(self.sql, self.parameters)
if not records:
return False
else:
Expand Down
39 changes: 35 additions & 4 deletions tests/sensors/test_sql_sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from airflow import DAG
from airflow import configuration
from airflow.exceptions import AirflowException
from airflow.sensors.sql_sensor import SqlSensor
from airflow.utils.timezone import datetime

Expand All @@ -39,27 +40,56 @@ def setUp(self):
}
self.dag = DAG(TEST_DAG_ID, default_args=args)

def test_unsupported_conn_type(self):
    """A connection type outside SqlSensor's allow-list must raise AirflowException."""
    sensor = SqlSensor(
        task_id='sql_sensor_check',
        conn_id='redis_default',
        sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
        dag=self.dag,
    )

    # Running the task should fail fast in poke() before any SQL is issued.
    with self.assertRaises(AirflowException):
        sensor.run(
            start_date=DEFAULT_DATE,
            end_date=DEFAULT_DATE,
            ignore_ti_state=True,
        )

@unittest.skipUnless(
    'mysql' in configuration.conf.get('core', 'sql_alchemy_conn'), "this is a mysql test")
def test_sql_sensor_mysql(self):
    """Run SqlSensor against MySQL, with and without query parameters."""
    # NOTE: the scraped diff left stray pre-image lines (`t = SqlSensor(` /
    # `t.run(...)`) inside this method; this is the reconstructed post-image.
    t1 = SqlSensor(
        task_id='sql_sensor_check',
        conn_id='mysql_default',
        sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
        dag=self.dag
    )
    t1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

    # Same query rendered with a bound parameter, exercising the new
    # `parameters` pass-through to hook.get_records().
    t2 = SqlSensor(
        task_id='sql_sensor_check',
        conn_id='mysql_default',
        sql="SELECT count(%s) FROM INFORMATION_SCHEMA.TABLES",
        parameters=["table_name"],
        dag=self.dag
    )
    t2.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

@unittest.skipUnless(
    'postgresql' in configuration.conf.get('core', 'sql_alchemy_conn'), "this is a postgres test")
def test_sql_sensor_postgres(self):
    """Run SqlSensor against Postgres, with and without query parameters."""
    # NOTE: the scraped diff left stray pre-image lines (`t = SqlSensor(` /
    # `t.run(...)`) inside this method; this is the reconstructed post-image.
    t1 = SqlSensor(
        task_id='sql_sensor_check',
        conn_id='postgres_default',
        sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
        dag=self.dag
    )
    t1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

    # Same query rendered with a bound parameter, exercising the new
    # `parameters` pass-through to hook.get_records().
    t2 = SqlSensor(
        task_id='sql_sensor_check',
        conn_id='postgres_default',
        sql="SELECT count(%s) FROM INFORMATION_SCHEMA.TABLES",
        parameters=["table_name"],
        dag=self.dag
    )
    t2.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

@mock.patch('airflow.sensors.sql_sensor.BaseHook')
def test_sql_sensor_postgres_poke(self, mock_hook):
Expand All @@ -69,6 +99,7 @@ def test_sql_sensor_postgres_poke(self, mock_hook):
sql="SELECT 1",
)

mock_hook.get_connection('postgres_default').conn_type = "postgres"
mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records

mock_get_records.return_value = []
Expand Down

0 comments on commit 4f27739

Please sign in to comment.