Skip to content

Commit

Permalink
Merge pull request adobe#41 from tamagoko/upgrade-pydantic
Browse files Browse the repository at this point in the history
upgrading to pydantic v2
  • Loading branch information
tamagoko authored Sep 15, 2023
2 parents baf2549 + 9abd65b commit a89f903
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 56 deletions.
3 changes: 1 addition & 2 deletions dysql/mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,10 @@ class MapperError(Exception):


class DbMapResultBase(abc.ABC):
_key_columns = ['id']

@classmethod
def get_key_columns(cls):
return cls._key_columns
return ['id']

@classmethod
def create_instance(cls, *args, **kwargs) -> 'DbMapResultBase':
Expand Down
39 changes: 12 additions & 27 deletions dysql/pydantic_mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,10 @@
NOTICE: Adobe permits you to use, modify, and distribute this file in accordance
with the terms of the Adobe license agreement accompanying it.
"""
import json
from json import JSONDecodeError
from typing import Any, Dict, Set

import sqlalchemy
from pydantic import BaseModel # pylint: disable=no-name-in-module
from pydantic.error_wrappers import ValidationError, ErrorWrapper
from pydantic import BaseModel, TypeAdapter

from .mappers import DbMapResultBase

Expand Down Expand Up @@ -44,19 +41,15 @@ class DbMapResultModel(BaseModel, DbMapResultBase):
@classmethod
def create_instance(cls, *args, **kwargs) -> 'DbMapResultModel':
# Uses the construct method to prevent validation when mapping results
return cls.construct(*args, **kwargs)
return cls.model_construct(*args, **kwargs)

def _map_json(self, current_dict: dict, record: sqlalchemy.engine.Row, field: str):
model_field = self.__fields__[field]
model_field = self.model_fields[field]
value = record[field]
if not value:
return
if not self._has_been_mapped():
try:
potential_json_data = record[field]
if potential_json_data:
current_dict[field] = json.loads(record[field])
except JSONDecodeError as exc:
return ErrorWrapper(ValueError(
f'Invalid JSON given to {model_field.alias}', exc), loc=model_field.alias)
return None
current_dict[field] = TypeAdapter(model_field.annotation).validate_json(value)

def _map_list(self, current_dict: dict, record: sqlalchemy.engine.Row, field: str):
if record[field] is None:
Expand Down Expand Up @@ -98,11 +91,9 @@ def _map_list_from_string(self, current_dict: dict, record: sqlalchemy.engine.Ro
list_string = str(list_string)
values_from_string = list(map(str.strip, list_string.split(',')))

model_field = self.__fields__[field]
model_field = self.model_fields[field]
# pre-validates the list we are expecting because we want to ensure all records are validated
values, errors_ = model_field.validate(values_from_string, current_dict, loc=model_field.alias)
if errors_:
raise ValidationError(errors_, DbMapResultModel)
values = TypeAdapter(model_field.annotation).validate_python(values_from_string)

if self._has_been_mapped() and current_dict[field]:
current_dict[field].extend(values)
Expand All @@ -122,17 +113,14 @@ def map_record(self, record: sqlalchemy.engine.Row) -> None:
- Remove all DB fields that are present in _dict_value_mappings since they were likely added above
:param record: the DB record
"""
errors = []
current_dict: dict = self.__dict__
for field in record.keys():
if field in self._list_fields:
self._map_list(current_dict, record, field)
elif field in self._csv_list_fields:
self._map_list_from_string(current_dict, record, field)
elif field in self._json_fields:
error = self._map_json(current_dict, record, field)
if error:
errors.append(error)
self._map_json(current_dict, record, field)
elif field in self._set_fields:
self._map_set(current_dict, record, field)
elif field in self._dict_key_fields:
Expand All @@ -142,12 +130,9 @@ def map_record(self, record: sqlalchemy.engine.Row) -> None:
if not self._has_been_mapped():
current_dict[field] = record[field]

if errors:
raise ValidationError(errors, DbMapResultModel)
# Remove all dict value fields (if present)
for db_field in self._dict_value_mappings.values():
current_dict.pop(db_field, None)

if self._has_been_mapped():
# At this point, just update the previous record
self.__dict__.update(current_dict)
Expand All @@ -156,7 +141,7 @@ def map_record(self, record: sqlalchemy.engine.Row) -> None:
self.__init__(**current_dict)

def raw(self) -> dict:
return self.dict()
return self.model_dump()

def has(self, field: str) -> bool:
    """Return True if *field* has been mapped onto this result.

    Checks membership directly in the instance ``__dict__`` rather than
    going through attribute lookup, so class-level defaults do not count
    as "present".

    :param field: the field/column name to look for
    :return: True if the field exists on this mapped instance
    """
    return field in self.__dict__
Expand All @@ -166,7 +151,7 @@ def _has_been_mapped(self):
Tells if a record has already been mapped onto this class or not.
:return: True if map_record has already been called, False otherwise
"""
return bool(getattr(self, '__fields_set__', False))
return bool(self.model_fields_set)

def get(self, field: str, default: Any = None) -> Any:
    """Return the mapped value for *field*, or *default* when absent.

    Mirrors ``dict.get`` semantics over the instance ``__dict__``: no
    exception is raised for missing fields.

    :param field: the field/column name to read
    :param default: value returned when the field was never mapped
    :return: the mapped value or *default*
    """
    return self.__dict__.get(field, default)
20 changes: 10 additions & 10 deletions dysql/test/test_mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,15 +118,15 @@ def test_single_column_and_row(mapper):
])
def test_key_mapper_key_has_multiple(mapper, expected):
result = mapper.map_records([
TestRow(('column_named_something', 'column_with_some_value'), ['a', 1]),
TestRow(('column_named_something', 'column_with_some_value'), ['a', 2]),
TestRow(('column_named_something', 'column_with_some_value'), ['a', 3]),
TestRow(('column_named_something', 'column_with_some_value'), ['a', 4]),
TestRow(('column_named_something', 'column_with_some_value'), ['b', 3]),
TestRow(('column_named_something', 'column_with_some_value'), ['b', 4]),
TestRow(('column_named_something', 'column_with_some_value'), ['b', 5]),
TestRow(('column_named_something', 'column_with_some_value'), ['b', 6]),
TestRow(('column_named_something', 'column_with_some_value'), ['b', 7]),
HelperRow(('column_named_something', 'column_with_some_value'), ['a', 1]),
HelperRow(('column_named_something', 'column_with_some_value'), ['a', 2]),
HelperRow(('column_named_something', 'column_with_some_value'), ['a', 3]),
HelperRow(('column_named_something', 'column_with_some_value'), ['a', 4]),
HelperRow(('column_named_something', 'column_with_some_value'), ['b', 3]),
HelperRow(('column_named_something', 'column_with_some_value'), ['b', 4]),
HelperRow(('column_named_something', 'column_with_some_value'), ['b', 5]),
HelperRow(('column_named_something', 'column_with_some_value'), ['b', 6]),
HelperRow(('column_named_something', 'column_with_some_value'), ['b', 7]),
])
assert len(result) == len(expected)
assert result == expected
Expand All @@ -137,7 +137,7 @@ def test_key_mapper_key_value_same():
KeyValueMapper(key_column='same', value_column='same')


class TestRow: # pylint: disable=too-few-public-methods
class HelperRow: # pylint: disable=too-few-public-methods
"""
Helper class does the most basic functionality we see when accessing records passed in
"""
Expand Down
35 changes: 21 additions & 14 deletions dysql/test/test_pydantic_mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing import Any, Dict, List, Set, Optional
import pytest

from pydantic.error_wrappers import ValidationError
from pydantic import ValidationError

from dysql import (
RecordCombiningMapper,
Expand Down Expand Up @@ -46,7 +46,7 @@ class ListWithStringsModel(DbMapResultModel):
_csv_list_fields: Set[str] = {'list1', 'list2'}

id: int
list1: Optional[List[str]]
list1: Optional[List[str]] = None
list2: List[int] = [] # help test empty list gets filled


Expand All @@ -55,11 +55,15 @@ class JsonModel(DbMapResultModel):

id: int
json1: dict
json2: Optional[dict]
json2: Optional[dict] = None


class MultiKeyModel(DbMapResultModel):
_key_columns = ['a', 'b']

@classmethod
def get_key_columns(cls):
return ['a', 'b']

_list_fields = {'c'}
a: int
b: str
Expand Down Expand Up @@ -238,7 +242,7 @@ def test_csv_list_field_without_mapping_ignored():

def test_csv_list_field_invalid_type():
mapper = RecordCombiningMapper(record_mapper=ListWithStringsModel)
with pytest.raises(ValidationError, match="value is not a valid integer"):
with pytest.raises(ValidationError, match="1 validation error for list"):
mapper.map_records([{
'id': 1,
'list1': 'a,b',
Expand All @@ -265,7 +269,7 @@ def test_json_field():
}
}
})
}]).dict() == {
}]).model_dump() == {
'id': 1,
'json1': {
'a': 1,
Expand All @@ -282,17 +286,20 @@ def test_json_field():
}


def test_invalid_json():
with pytest.raises(ValidationError) as excinfo:
@pytest.mark.parametrize('json1, json2', [
('{ "json": value', None),
('{ "json": value', '{ "json": value }'),
('{ "json": value }', '{ "json": value'),
(None, '{ "json": value'),
])
def test_invalid_json(json1, json2):
with pytest.raises(ValidationError, match='Invalid JSON'):
mapper = SingleRowMapper(record_mapper=JsonModel)
mapper.map_records([{
'id': 1,
'json1': '{ "json": value',
'json2': 'just a string'
'json1': json1,
'json2': json2
}])
assert len(excinfo.value.args[0]) == 2
assert excinfo.value.args[0][0].exc.args[0] == 'Invalid JSON given to json1'
assert excinfo.value.args[0][1].exc.args[0] == 'Invalid JSON given to json2'


def test_json_none():
Expand All @@ -301,7 +308,7 @@ def test_json_none():
'id': 1,
'json1': '{ "first": "value" }',
'json2': None
}]).dict() == {
}]).model_dump() == {
'id': 1,
'json1': {
'first': 'value',
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
-e .

# Used in development, and as an extra
pydantic>=1.8.2,<2
pydantic>2
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from setuptools import setup, find_packages


BASE_VERSION = '1.13'
BASE_VERSION = '2.0'
SOURCE_DIR = os.path.dirname(
os.path.abspath(__file__)
)
Expand Down Expand Up @@ -81,7 +81,7 @@ def get_version():
'sqlalchemy<2',
),
extras_require={
'pydantic': ['pydantic>=1.8.2,<2'],
'pydantic': ['pydantic>2'],
},
classifiers=[
'Development Status :: 4 - Beta',
Expand Down

0 comments on commit a89f903

Please sign in to comment.