Skip to content

Commit

Permalink
Reuse re.Pattern object in regex patterns (#1318)
Browse files Browse the repository at this point in the history
  • Loading branch information
sydney-runkle authored Jun 11, 2024
1 parent 8afaa45 commit d7946da
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 27 deletions.
22 changes: 11 additions & 11 deletions generate_self_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@
import decimal
import importlib.util
import re
import sys
from collections.abc import Callable
from datetime import date, datetime, time, timedelta
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, ForwardRef, List, Set, Type, Union
from typing import TYPE_CHECKING, Any, Dict, ForwardRef, List, Pattern, Set, Type, Union

from typing_extensions import TypedDict, get_args, get_origin, is_typeddict

Expand Down Expand Up @@ -46,7 +47,7 @@
schema_ref_validator = {'type': 'definition-ref', 'schema_ref': 'root-schema'}


def get_schema(obj: Any, definitions: dict[str, core_schema.CoreSchema]) -> core_schema.CoreSchema:
def get_schema(obj: Any, definitions: dict[str, core_schema.CoreSchema]) -> core_schema.CoreSchema: # noqa: C901
if isinstance(obj, str):
return {'type': obj}
elif obj in (datetime, timedelta, date, time, bool, int, float, str, decimal.Decimal):
Expand Down Expand Up @@ -81,6 +82,9 @@ def get_schema(obj: Any, definitions: dict[str, core_schema.CoreSchema]) -> core
elif issubclass(origin, Type):
# can't really use 'is-instance' since this is used for the class_ parameter of 'is-instance' validators
return {'type': 'any'}
elif origin in (Pattern, re.Pattern):
# can't really use 'is-instance' easily with Pattern, so we use `any` as a placeholder for now
return {'type': 'any'}
else:
# debug(obj)
raise TypeError(f'Unknown type: {obj!r}')
Expand Down Expand Up @@ -189,16 +193,12 @@ def all_literal_values(type_: type[core_schema.Literal]) -> list[any]:


def eval_forward_ref(type_: Any) -> Any:
try:
try:
# Python 3.12+
return type_._evaluate(core_schema.__dict__, None, type_params=set(), recursive_guard=set())
except TypeError:
# Python 3.9+
return type_._evaluate(core_schema.__dict__, None, set())
except TypeError:
# for Python 3.8
if sys.version_info < (3, 9):
return type_._evaluate(core_schema.__dict__, None)
elif sys.version_info < (3, 12, 4):
return type_._evaluate(core_schema.__dict__, None, recursive_guard=set())
else:
return type_._evaluate(core_schema.__dict__, None, type_params=set(), recursive_guard=set())


def main() -> None:
Expand Down
6 changes: 3 additions & 3 deletions python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from collections.abc import Mapping
from datetime import date, datetime, time, timedelta
from decimal import Decimal
from typing import TYPE_CHECKING, Any, Callable, Dict, Hashable, List, Set, Tuple, Type, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, Hashable, List, Pattern, Set, Tuple, Type, Union

from typing_extensions import deprecated

Expand Down Expand Up @@ -744,7 +744,7 @@ def decimal_schema(

class StringSchema(TypedDict, total=False):
type: Required[Literal['str']]
pattern: str
pattern: Union[str, Pattern[str]]
max_length: int
min_length: int
strip_whitespace: bool
Expand All @@ -760,7 +760,7 @@ class StringSchema(TypedDict, total=False):

def str_schema(
*,
pattern: str | None = None,
pattern: str | Pattern[str] | None = None,
max_length: int | None = None,
min_length: int | None = None,
strip_whitespace: bool | None = None,
Expand Down
55 changes: 42 additions & 13 deletions src/validators/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ impl StrConstrainedValidator {
.map(|s| s.to_str())
.transpose()?
.unwrap_or(RegexEngine::RUST_REGEX);
Pattern::compile(py, s, regex_engine)
Pattern::compile(s, regex_engine)
})
.transpose()?;
let min_length: Option<usize> =
Expand Down Expand Up @@ -230,18 +230,47 @@ impl RegexEngine {
}

impl Pattern {
fn compile(py: Python<'_>, pattern: String, engine: &str) -> PyResult<Self> {
let engine = match engine {
RegexEngine::RUST_REGEX => {
RegexEngine::RustRegex(Regex::new(&pattern).map_err(|e| py_schema_error_type!("{}", e))?)
}
RegexEngine::PYTHON_RE => {
let re_compile = py.import_bound(intern!(py, "re"))?.getattr(intern!(py, "compile"))?;
RegexEngine::PythonRe(re_compile.call1((&pattern,))?.into())
}
_ => return Err(py_schema_error_type!("Invalid regex engine: {}", engine)),
};
Ok(Self { pattern, engine })
fn extract_pattern_str(pattern: &Bound<'_, PyAny>) -> PyResult<String> {
if pattern.is_instance_of::<PyString>() {
Ok(pattern.to_string())
} else {
pattern
.getattr("pattern")
.and_then(|attr| attr.extract::<String>())
.map_err(|_| py_schema_error_type!("Invalid pattern, must be str or re.Pattern: {}", pattern))
}
}

fn compile(pattern: Bound<'_, PyAny>, engine: &str) -> PyResult<Self> {
let pattern_str = Self::extract_pattern_str(&pattern)?;

let py = pattern.py();

let re_module = py.import_bound(intern!(py, "re"))?;
let re_compile = re_module.getattr(intern!(py, "compile"))?;
let re_pattern = re_module.getattr(intern!(py, "Pattern"))?;

if pattern.is_instance(&re_pattern)? {
// if the pattern is already a compiled regex object, we default to using the python re engine
// so that any flags, etc. are preserved
Ok(Self {
pattern: pattern_str,
engine: RegexEngine::PythonRe(pattern.to_object(py)),
})
} else {
let engine = match engine {
RegexEngine::RUST_REGEX => {
RegexEngine::RustRegex(Regex::new(&pattern_str).map_err(|e| py_schema_error_type!("{}", e))?)
}
RegexEngine::PYTHON_RE => RegexEngine::PythonRe(re_compile.call1((pattern,))?.into()),
_ => return Err(py_schema_error_type!("Invalid regex engine: {}", engine)),
};

Ok(Self {
pattern: pattern_str,
engine,
})
}
}

fn is_match(&self, py: Python<'_>, target: &str) -> PyResult<bool> {
Expand Down
7 changes: 7 additions & 0 deletions tests/validators/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,3 +398,10 @@ def test_coerce_numbers_to_str_schema_with_strict_mode(number: int):
v.validate_python(number)
with pytest.raises(ValidationError):
v.validate_json(str(number))


@pytest.mark.parametrize('engine', [None, 'rust-regex', 'python-re'])
def test_compiled_regex(engine) -> None:
v = SchemaValidator(core_schema.str_schema(pattern=re.compile('abc', re.IGNORECASE), regex_engine=engine))
assert v.validate_python('abc') == 'abc'
assert v.validate_python('ABC') == 'ABC'

0 comments on commit d7946da

Please sign in to comment.