Skip to content

Commit

Permalink
Validate connection IDs (#31140)
Browse files Browse the repository at this point in the history
* Validate connection_ids

(cherry picked from commit 5cb8ef8)
  • Loading branch information
mattusifer authored and Elad Kalif committed Jun 8, 2023
1 parent 9682a12 commit f033241
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 1 deletion.
3 changes: 2 additions & 1 deletion airflow/www/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from airflow.providers_manager import ProvidersManager
from airflow.utils import timezone
from airflow.utils.types import DagRunType
from airflow.www.validators import ValidKey
from airflow.www.widgets import (
AirflowDateTimePickerROWidget,
AirflowDateTimePickerWidget,
Expand Down Expand Up @@ -205,7 +206,7 @@ def _iter_connection_types() -> Iterator[tuple[str, str]]:
class ConnectionForm(DynamicForm):
conn_id = StringField(
lazy_gettext("Connection Id"),
validators=[InputRequired()],
validators=[InputRequired(), ValidKey()],
widget=BS3TextFieldWidget(),
)
conn_type = SelectField(
Expand Down
21 changes: 21 additions & 0 deletions airflow/www/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

from wtforms.validators import EqualTo, ValidationError

from airflow.utils import helpers


class GreaterEqualThan(EqualTo):
"""Compares the values of two fields.
Expand Down Expand Up @@ -76,3 +78,22 @@ def __call__(self, form, field):
except JSONDecodeError as ex:
message = self.message or f"JSON Validation Error: {ex}"
raise ValidationError(message=field.gettext(message.format(field.data)))


class ValidKey:
"""
Validates values that will be used as keys
:param max_length:
The maximum length of the given key
"""

def __init__(self, max_length=200):
self.max_length = max_length

def __call__(self, form, field):
if field.data:
try:
helpers.validate_key(field.data, self.max_length)
except Exception as e:
raise ValidationError(str(e))
35 changes: 35 additions & 0 deletions tests/www/test_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,38 @@ def test_validation_raises_custom_message(self):
self._validate(
message="Invalid JSON: {}",
)


class TestValidKey:
def setup_method(self):
self.form_field_mock = mock.MagicMock(data="valid_key")
self.form_field_mock.gettext.side_effect = lambda msg: msg
self.form_mock = mock.MagicMock(spec_set=dict)

def _validate(self):
validator = validators.ValidKey()

return validator(self.form_mock, self.form_field_mock)

def test_form_field_is_none(self):
self.form_field_mock.data = None

assert self._validate() is None

def test_validation_pass(self):
assert self._validate() is None

def test_validation_fails_with_trailing_whitespace(self):
self.form_field_mock.data = "invalid key "

with pytest.raises(validators.ValidationError):
self._validate()

def test_validation_fails_with_too_many_characters(self):
self.form_field_mock.data = "".join("x" for _ in range(1000))

with pytest.raises(
validators.ValidationError,
match=r"The key has to be less than [0-9]+ characters",
):
self._validate()
11 changes: 11 additions & 0 deletions tests/www/views/test_views_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,17 @@ def test_create_connection(admin_client, session):
_check_last_log(session, dag_id=None, event="connection.create", execution_date=None)


def test_invalid_connection_id_trailing_blanks(admin_client, session):
invalid_conn_id = "conn_id_with_trailing_blanks "
invalid_connection = {**CONNECTION, "conn_id": invalid_conn_id}
resp = admin_client.post("/connection/add", data=invalid_connection, follow_redirects=True)
check_content_in_response(
f"The key '{invalid_conn_id}' has to be made of alphanumeric characters, "
+ "dashes, dots and underscores exclusively",
resp,
)


def test_action_logging_connection_masked_secrets(session, admin_client):
admin_client.post("/connection/add", data=conn_with_extra(), follow_redirects=True)
_check_last_log_masked_connection(session, dag_id=None, event="connection.create", execution_date=None)
Expand Down

0 comments on commit f033241

Please sign in to comment.