Skip to content

Commit

Permalink
Fix broken regex for allowed_deserialization_classes (#36147)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Victor Dominguite <[email protected]>
Co-authored-by: Elad Kalif <[email protected]>
(cherry picked from commit 20cb70b)
  • Loading branch information
tobiaszorzetto authored and potiuk committed Feb 13, 2024
1 parent 7331df8 commit 9d9d62e
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 12 deletions.
15 changes: 12 additions & 3 deletions airflow/config_templates/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -241,11 +241,20 @@ core:
allowed_deserialization_classes:
description: |
What classes can be imported during deserialization. This is a multi line value.
The individual items will be parsed as regexp. Python built-in classes (like dict)
are always allowed. Bare "." will be replaced so you can set airflow.* .
The individual items will be parsed as a pattern to a glob function.
Python built-in classes (like dict) are always allowed.
version_added: 2.5.0
type: string
default: 'airflow\..*'
default: 'airflow.*'
example: ~
allowed_deserialization_classes_regexp:
description: |
What classes can be imported during deserialization. This is a multi line value.
The individual items will be parsed as regexp patterns.
This is a secondary option to ``allowed_deserialization_classes``.
version_added: 2.8.1
type: string
default: ''
example: ~
killed_task_cleanup_time:
description: |
Expand Down
2 changes: 1 addition & 1 deletion airflow/config_templates/unit_tests.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ unit_test_mode = True
# We want to use a shorter timeout for task cleanup
killed_task_cleanup_time = 5
# We only allow our own classes to be deserialized in tests
allowed_deserialization_classes = airflow\..* tests\..*
allowed_deserialization_classes = airflow.* tests.*

[database]

Expand Down
27 changes: 23 additions & 4 deletions airflow/serialization/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import functools
import logging
import sys
from fnmatch import fnmatch
from importlib import import_module
from typing import TYPE_CHECKING, Any, Pattern, TypeVar, Union, cast

Expand Down Expand Up @@ -241,7 +242,6 @@ def deserialize(o: T | None, full=True, type_hint: Any = None) -> object:
# only return string representation
if not full:
return _stringify(classname, version, value)

if not _match(classname) and classname not in _extra_allowed:
raise ImportError(
f"{classname} was not found in allow list for deserialization imports. "
Expand Down Expand Up @@ -288,7 +288,22 @@ def _convert(old: dict) -> dict:


def _match(classname: str) -> bool:
return any(p.match(classname) is not None for p in _get_patterns())
"""Checks if the given classname matches a path pattern either using glob format or regexp format."""
return _match_glob(classname) or _match_regexp(classname)


@functools.lru_cache(maxsize=None)
def _match_glob(classname: str):
"""Checks if the given classname matches a pattern from allowed_deserialization_classes using glob syntax."""
patterns = _get_patterns()
return any(fnmatch(classname, p.pattern) for p in patterns)


@functools.lru_cache(maxsize=None)
def _match_regexp(classname: str):
"""Checks if the given classname matches a pattern from allowed_deserialization_classes_regexp using regexp."""
patterns = _get_regexp_patterns()
return any(p.match(classname) is not None for p in patterns)


def _stringify(classname: str, version: int, value: T | None) -> str:
Expand Down Expand Up @@ -359,8 +374,12 @@ def _register():

@functools.lru_cache(maxsize=None)
def _get_patterns() -> list[Pattern]:
patterns = conf.get("core", "allowed_deserialization_classes").split()
return [re2.compile(re2.sub(r"(\w)\.", r"\1\..", p)) for p in patterns]
return [re2.compile(p) for p in conf.get("core", "allowed_deserialization_classes").split()]


@functools.lru_cache(maxsize=None)
def _get_regexp_patterns() -> list[Pattern]:
return [re2.compile(p) for p in conf.get("core", "allowed_deserialization_classes_regexp").split()]


_register()
11 changes: 11 additions & 0 deletions newsfragments/36147.significant.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
The ``allowed_deserialization_classes`` flag now follows a glob pattern.

For example if one wants to add the class ``airflow.tests.custom_class`` to the
``allowed_deserialization_classes`` list, it can be done by writing the full class
name (``airflow.tests.custom_class``) or a pattern such as the ones used in glob
search (e.g., ``airflow.*``, ``airflow.tests.*``).

If you currently use a custom regexp path make sure to rewrite it as a glob pattern.

Alternatively, if you still wish to match it as a regexp pattern, add it under the new
list ``allowed_deserialization_classes_regexp`` instead.
58 changes: 54 additions & 4 deletions tests/serialization/test_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@
SCHEMA_ID,
VERSION,
_get_patterns,
_get_regexp_patterns,
_match,
_match_glob,
_match_regexp,
deserialize,
serialize,
)
Expand All @@ -44,10 +47,16 @@
@pytest.fixture()
def recalculate_patterns():
_get_patterns.cache_clear()
_get_regexp_patterns.cache_clear()
_match_glob.cache_clear()
_match_regexp.cache_clear()
try:
yield
finally:
_get_patterns.cache_clear()
_get_regexp_patterns.cache_clear()
_match_glob.cache_clear()
_match_regexp.cache_clear()


class Z:
Expand Down Expand Up @@ -218,7 +227,7 @@ def test_serder_dataclass(self):

@conf_vars(
{
("core", "allowed_deserialization_classes"): "airflow[.].*",
("core", "allowed_deserialization_classes"): "airflow.*",
}
)
@pytest.mark.usefixtures("recalculate_patterns")
Expand All @@ -232,13 +241,54 @@ def test_allow_list_for_imports(self):

@conf_vars(
{
("core", "allowed_deserialization_classes"): "tests.*",
("core", "allowed_deserialization_classes"): "tests.airflow.*",
}
)
@pytest.mark.usefixtures("recalculate_patterns")
def test_allow_list_replace(self):
def test_allow_list_match(self):
assert _match("tests.airflow.deep")
assert _match("testsfault") is False
assert _match("tests.wrongpath") is False

@conf_vars(
{
("core", "allowed_deserialization_classes"): "tests.airflow.deep",
}
)
@pytest.mark.usefixtures("recalculate_patterns")
def test_allow_list_match_class(self):
"""Test the match function when passing a full classname as
allowed_deserialization_classes
"""
assert _match("tests.airflow.deep")
assert _match("tests.airflow.FALSE") is False

@conf_vars(
{
("core", "allowed_deserialization_classes"): "",
("core", "allowed_deserialization_classes_regexp"): "tests\.airflow\..",
}
)
@pytest.mark.usefixtures("recalculate_patterns")
def test_allow_list_match_regexp(self):
"""Test the match function when passing a path as
allowed_deserialization_classes_regexp with no glob pattern defined
"""
assert _match("tests.airflow.deep")
assert _match("tests.wrongpath") is False

@conf_vars(
{
("core", "allowed_deserialization_classes"): "",
("core", "allowed_deserialization_classes_regexp"): "tests\.airflow\.deep",
}
)
@pytest.mark.usefixtures("recalculate_patterns")
def test_allow_list_match_class_regexp(self):
"""Test the match function when passing a full classname as
allowed_deserialization_classes_regexp with no glob pattern defined
"""
assert _match("tests.airflow.deep")
assert _match("tests.airflow.FALSE") is False

def test_incompatible_version(self):
data = dict(
Expand Down

0 comments on commit 9d9d62e

Please sign in to comment.