Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

v0.7.0 #21

Merged
merged 4 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
# Changelog

## v0.7.0 (2024-10-22)

### Changed

- Aborting a script will now show the traceback

### Fixed

- Confit should no longer cause pydantic v1 deprecation warnings

## v0.6.0 (2024-09-13)

### Fixed
Expand Down
2 changes: 1 addition & 1 deletion confit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@
VisibleDeprecationWarning,
)

__version__ = "0.6.0"
__version__ = "0.7.0"
2 changes: 2 additions & 0 deletions confit/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,8 @@ def command(ctx: Context, config: Optional[List[Path]] = None):
print("Validation error:", file=sys.stderr, end=" ")
print(str(e), file=sys.stderr)
sys.exit(1)
except KeyboardInterrupt as e: # pragma: no cover
raise Exception("Interrupted by user") from e

return validated

Expand Down
9 changes: 5 additions & 4 deletions confit/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from confit.utils.xjson import Reference

Loc = Tuple[Union[int, str]]
PYDANTIC_V1 = pydantic.VERSION.split(".")[0] == "1"


class MissingReference(Exception):
Expand Down Expand Up @@ -243,11 +244,11 @@ def patch_errors(
# field_model.vd.model, pydantic.BaseModel
# ):
# field_model = field_model.vd.model
if hasattr(field_model, "model_fields"):
field_model = field_model.model_fields[part]
else:
if PYDANTIC_V1:
field_model = field_model.__fields__[part]
if hasattr(field_model, "type_"):
else:
field_model = field_model.model_fields[part]
if PYDANTIC_V1:
field_model = field_model.type_
else:
field_model = field_model.annotation
Expand Down
23 changes: 13 additions & 10 deletions confit/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,12 @@ def invoked(kw):
# "self" must be passed as a positional argument
if use_self:
kw = {**kw, self_name: resolved}
extras = [
key
for key in kw
if key not in pydantic_func.model.__fields__ and key != extras_name
]
fields = (
pydantic_func.model.__fields__
if PYDANTIC_V1
else pydantic_func.model.model_fields
)
extras = [key for key in kw if key not in fields and key != extras_name]
try:
model_instance = pydantic_func.model(
**{
Expand Down Expand Up @@ -183,10 +184,10 @@ def validate(_func: Callable) -> Callable:
else:
vd = ValidatedFunction(_func.__init__, config)
vd.model.__name__ = _func.__name__
if hasattr(vd.model, "model_fields"):
vd.model.model_fields["self"].default = None
else:
if PYDANTIC_V1:
vd.model.__fields__["self"].default = None
else:
vd.model.model_fields["self"].default = None

# This function is called by Pydantic when asked to cast
# a value (most likely a dict) as a Model (most often during
Expand Down Expand Up @@ -293,8 +294,10 @@ def wrapper_function(*args: Any, **kwargs: Any) -> Any:
raise e.with_traceback(remove_lib_from_traceback(e.__traceback__))

_func.vd = vd
_func.__get_validators__ = __get_validators__
_func.__get_pydantic_core_schema__ = __get_pydantic_core_schema__
if PYDANTIC_V1:
_func.__get_validators__ = __get_validators__
else:
_func.__get_pydantic_core_schema__ = __get_pydantic_core_schema__
# _func.model = vd.model
# _func.model.type_ = _func
_func.__init__ = wrapper_function
Expand Down
61 changes: 57 additions & 4 deletions tests/test_as_list.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,48 @@
from typing import Any, Generic, List, TypeVar
from dataclasses import is_dataclass
from typing import Generic, List, TypeVar

import pydantic
import pytest
from pydantic import BaseModel
from typing_extensions import is_typeddict

from confit import validate_arguments
from confit.errors import ConfitValidationError, patch_errors

T = TypeVar("T")
if pydantic.VERSION < "2":

def cast(type_, obj):
class Model(pydantic.BaseModel):
__root__: type_

class Config:
arbitrary_types_allowed = True

return Model(__root__=obj).__root__

else:
from pydantic.type_adapter import ConfigDict, TypeAdapter
from pydantic_core import core_schema

def make_type_adapter(type_):
config = None
if not issubclass(type, BaseModel) or is_dataclass(type) or is_typeddict(type):
config = ConfigDict(arbitrary_types_allowed=True)
return TypeAdapter(type_, config=config)

def cast(type_, obj):
return make_type_adapter(type_).validate_python(obj)


class MetaAsList(type):
def __init__(cls, name, bases, dct):
super().__init__(name, bases, dct)
cls.item = Any
cls.type_ = List

def __getitem__(self, item):
new_type = MetaAsList(self.__name__, (self,), {})
new_type.item = item
new_type.type_ = List[item]
return new_type

def validate(cls, value, config=None):
Expand All @@ -25,7 +51,7 @@ def validate(cls, value, config=None):
if not isinstance(value, list):
value = [value]
try:
return pydantic.parse_obj_as(List[cls.item], value)
return cast(cls.type_, value)
except pydantic.ValidationError as e:
e = patch_errors(e, drop_names=("__root__",))
e.model = cls
Expand All @@ -34,6 +60,9 @@ def validate(cls, value, config=None):
def __get_validators__(cls):
yield cls.validate

def __get_pydantic_core_schema__(cls, source, handler):
return core_schema.no_info_plain_validator_function(cls.validate)


class AsList(Generic[T], metaclass=MetaAsList):
pass
Expand All @@ -52,3 +81,27 @@ def func(a: AsList[int]):
assert (
"1 validation error for test_as_list.test_as_list.<locals>.func()\n" "-> a.0\n"
) in str(e.value)


class CustomMeta(type):
def __getattr__(self, item):
raise AttributeError(item)

def __dir__(self):
return super().__dir__()


class Custom:
def __init__(self, value: int):
self.value = value


def test_as_list_custom():
@validate_arguments
def func(a: AsList[Custom]):
return [x.value for x in a]

assert func(Custom(4)) == [4]

with pytest.raises(ConfitValidationError):
func({"data": "ok"})
Loading