Skip to content

Commit

Permalink
Allow creating objects from primitive objects
Browse files Browse the repository at this point in the history
  • Loading branch information
Peter-van-Tol committed Sep 22, 2024
1 parent d62f540 commit 01dcbc8
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 32 deletions.
66 changes: 34 additions & 32 deletions src/pydantic_shapely/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,51 +90,53 @@ def validate(self, value) -> BaseGeometry:
ValueError: If the input value is not a valid WKT-string or if the
supplied geometry is not of the expected type.
"""
# Test whether user supplied the geometry directly
# - Test whether user supplied the geometry directly
if isinstance(value, BaseGeometry):
geometry = value
# - convert a (WKT-) string to a object
elif isinstance(value, str):
try:
geometry: BaseGeometry = shapely.from_wkt(value)
except Exception as ex:
raise ValueError("Supplied string is not a valid WKT-string") from ex
# - last resort, pass the value to the constructor of shapely
else:
# - get the types that are supported by the field
if isclass(self.__geometry_type__):
# The geometry type is a class, check if the geometry is an instance of the class
if isinstance(value, self.__geometry_type__):
return self._validate_z_values(value)
raise ValueError(
f"Supplied geometry ({value.geom_type}) is not a "
f"{self.__geometry_type__.__name__}."
)
supported_types = [self.__geometry_type__]
else:
# The geometry type is a Union, check if the geometry is an instance of any of the
# classes
supported_types = typing.get_args(self.__geometry_type__)
if any(isinstance(value, t) for t in supported_types):
return self._validate_z_values(value)
# - for each type, check we can instantiate the geometry with the value
# from the field
for t in supported_types:
try:
geometry = t(value)
break
except Exception:
pass
else:
raise ValueError(
f"Supplied geometry ({value.geom_type}) is not one of the expected "
f"types: {', '.join([t.__name__ for t in supported_types])}."
f"Supplied value ({value}) cannot be converted to a valid geometry."
)

# Convert the geometry to a point, the geometry should be a valid WKT
try:
geometry: BaseGeometry = shapely.from_wkt(value)
except Exception as ex:
raise ValueError("Supplied string is not a valid WKT-string") from ex
if isclass(self.__geometry_type__):
# The geometry type is a class, check if the geometry is an instance of the class
if not isinstance(geometry, self.__geometry_type__):
raise ValueError(
f"Supplied geometry ({geometry.geom_type}) is not a "
f"{self.__geometry_type__.__name__}."
)
if isinstance(geometry, self.__geometry_type__):
return self._validate_z_values(geometry)
raise ValueError(
f"Supplied geometry ({geometry.geom_type}) is not a "
f"{self.__geometry_type__.__name__}."
)
else:
# The geometry type is a Union, check if the geometry is an instance of any of the
# classes
supported_types = typing.get_args(self.__geometry_type__)
if not any(isinstance(geometry, t) for t in supported_types):
raise ValueError(
f"Supplied geometry ({geometry.geom_type}) is not one of the expected "
f"types: {', '.join([t.__name__ for t in supported_types])}."
)
# Custom validations
geometry = self._validate_z_values(geometry)
return geometry
if any(isinstance(geometry, t) for t in supported_types):
return self._validate_z_values(geometry)
raise ValueError(
f"Supplied geometry ({geometry.geom_type}) is not one of the expected "
f"types: {', '.join([t.__name__ for t in supported_types])}."
)

@staticmethod
def serialize(value) -> str:
Expand Down
20 changes: 20 additions & 0 deletions tests/test_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,20 @@
")",
}


EXAMPLES_BASE = {
Point: (10, 20),
LineString: ((10, 10), (20, 20), (21, 30)),
Polygon: (((0, 0), (0, 40), (40, 40), (40, 0), (0, 0))),
MultiPoint: (Point(0, 0), Point(10, 20), Point(15, 20), Point(30, 30)),
MultiLineString: (LineString(((10, 10), (20, 20))), LineString((15, 15), (30, 15))),
MultiPolygon: (
Polygon([(10, 10), (10, 20), (20, 20), (20, 15), (10, 10)],),
Polygon([(60, 60), (70, 70), (80, 60), (60, 60)],)
)
}


EXAMPLES_OBJ_2D = {
Point: Point(10, 20),
LineString: LineString([(10, 10), (20, 20), (21, 30)]),
Expand Down Expand Up @@ -113,6 +127,12 @@ def test_annotation_correct_geom_type_roundtrip(geom_type):
instance = model(geometry=EXAMPLES_OBJ_2D[geom_type])
assert model.model_validate_json(instance.model_dump_json()) == instance

if geom_type != GeometryCollection:
instance = model(geometry=EXAMPLES_BASE[geom_type])
assert model.model_validate_json(instance.model_dump_json()) == instance
instance = model(geometry=EXAMPLES_WKT[geom_type])
assert model.model_validate_json(instance.model_dump_json()) == instance


@pytest.mark.parametrize(
"geom_type",
Expand Down

0 comments on commit 01dcbc8

Please sign in to comment.