Skip to content

Commit

Permalink
Fix str subclass validation for enums (#1273)
Browse files Browse the repository at this point in the history
Co-authored-by: David Hewitt <[email protected]>
  • Loading branch information
sydney-runkle and davidhewitt authored May 21, 2024
1 parent b777774 commit 727deee
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 3 deletions.
24 changes: 21 additions & 3 deletions src/validators/enum_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::marker::PhantomData;
use pyo3::exceptions::PyTypeError;
use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList, PyType};
use pyo3::types::{PyDict, PyFloat, PyInt, PyList, PyString, PyType};

use crate::build_tools::{is_strict, py_schema_err};
use crate::errors::{ErrorType, ValError, ValResult};
Expand Down Expand Up @@ -167,9 +167,27 @@ impl EnumValidateValue for PlainEnumValidator {
py: Python<'py>,
input: &I,
lookup: &LiteralLookup<PyObject>,
_strict: bool,
strict: bool,
) -> ValResult<Option<PyObject>> {
Ok(lookup.validate(py, input)?.map(|(_, v)| v.clone_ref(py)))
match lookup.validate(py, input)? {
Some((_, v)) => Ok(Some(v.clone_ref(py))),
None => {
if !strict {
if let Some(py_input) = input.as_python() {
// necessary for compatibility with 2.6, where str and int subclasses are allowed
if py_input.is_instance_of::<PyString>() {
return Ok(lookup.validate_str(input, false)?.map(|v| v.clone_ref(py)));
} else if py_input.is_instance_of::<PyInt>() {
return Ok(lookup.validate_int(py, input, false)?.map(|v| v.clone_ref(py)));
// necessary for compatibility with 2.6, where float values are allowed for int enums in lax mode
} else if py_input.is_instance_of::<PyFloat>() {
return Ok(lookup.validate_int(py, input, false)?.map(|v| v.clone_ref(py)));
}
}
}
Ok(None)
}
}
}
}

Expand Down
46 changes: 46 additions & 0 deletions tests/validators/test_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,52 @@ class MyEnum(Enum):
SchemaValidator(core_schema.enum_schema(MyEnum, []))


def test_enum_with_str_subclass() -> None:
class MyEnum(Enum):
a = 'a'
b = 'b'

v = SchemaValidator(core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())))

assert v.validate_python(MyEnum.a) is MyEnum.a
assert v.validate_python('a') is MyEnum.a

class MyStr(str):
pass

assert v.validate_python(MyStr('a')) is MyEnum.a
with pytest.raises(ValidationError):
v.validate_python(MyStr('a'), strict=True)


def test_enum_with_int_subclass() -> None:
class MyEnum(Enum):
a = 1
b = 2

v = SchemaValidator(core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())))

assert v.validate_python(MyEnum.a) is MyEnum.a
assert v.validate_python(1) is MyEnum.a

class MyInt(int):
pass

assert v.validate_python(MyInt(1)) is MyEnum.a
with pytest.raises(ValidationError):
v.validate_python(MyInt(1), strict=True)


def test_validate_float_for_int_enum() -> None:
class MyEnum(int, Enum):
a = 1
b = 2

v = SchemaValidator(core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values())))

assert v.validate_python(1.0) is MyEnum.a


def test_missing_error_converted_to_val_error() -> None:
class MyFlags(IntFlag):
OFF = 0
Expand Down

0 comments on commit 727deee

Please sign in to comment.