From ed02392cb4d13d9c54670c92dc709ae810f59e14 Mon Sep 17 00:00:00 2001 From: Hofer-Julian Date: Wed, 13 Mar 2024 09:54:05 +0100 Subject: [PATCH 1/3] Fix model plotting while still ensuring that Edge tables are written Fixes #1236 --- python/ribasim/ribasim/geometry/edge.py | 6 +----- python/ribasim/ribasim/geometry/node.py | 2 +- python/ribasim/ribasim/input_base.py | 6 +++--- 3 files changed, 5 insertions(+), 9 deletions(-) diff --git a/python/ribasim/ribasim/geometry/edge.py b/python/ribasim/ribasim/geometry/edge.py index 7bb69e944..e905933c3 100644 --- a/python/ribasim/ribasim/geometry/edge.py +++ b/python/ribasim/ribasim/geometry/edge.py @@ -8,7 +8,7 @@ from geopandas import GeoDataFrame from matplotlib.axes import Axes from numpy.typing import NDArray -from pandera.typing import DataFrame, Series +from pandera.typing import Series from pandera.typing.geopandas import GeoSeries from shapely.geometry import LineString, MultiLineString, Point @@ -42,10 +42,6 @@ class Config: class EdgeTable(SpatialTableModel[EdgeSchema]): """Defines the connections between nodes.""" - def __init__(self, **kwargs): - kwargs.setdefault("df", DataFrame[EdgeSchema]()) - super().__init__(**kwargs) - def add( self, from_node: NodeData, diff --git a/python/ribasim/ribasim/geometry/node.py b/python/ribasim/ribasim/geometry/node.py index 72f610745..330cc809b 100644 --- a/python/ribasim/ribasim/geometry/node.py +++ b/python/ribasim/ribasim/geometry/node.py @@ -15,7 +15,7 @@ class NodeSchema(pa.SchemaModel): - node_id: Series[int] + node_id: Series[int] = pa.Field(nullable=False, default=0) name: Series[str] = pa.Field(default="") node_type: Series[str] = pa.Field(default="") subnetwork_id: Series[pd.Int64Dtype] = pa.Field( diff --git a/python/ribasim/ribasim/input_base.py b/python/ribasim/ribasim/input_base.py index c76ca6cbd..b08160a60 100644 --- a/python/ribasim/ribasim/input_base.py +++ b/python/ribasim/ribasim/input_base.py @@ -235,7 +235,7 @@ def _save( if self.df is not None and self.filepath is not None: self.sort() self._write_arrow(self.filepath, directory, input_dir) - elif self.df is not None and db_path is not None: + elif db_path is not None: self.sort() self._write_table(db_path) @@ -358,8 +358,8 @@ def _write_table(self, path: FilePath) -> None: ---------- path : FilePath """ - - gdf = gpd.GeoDataFrame(data=self.df) + df = DataFrame[self.tableschema()]() if self.df is None else self.df # type:ignore + gdf = gpd.GeoDataFrame(data=df) gdf = gdf.set_geometry("geometry") gdf.to_file(path, layer=self.tablename(), driver="GPKG", mode="a") From b0565a8ae63e327f56a0a5ed04940f9c76eff08a Mon Sep 17 00:00:00 2001 From: Hofer-Julian Date: Wed, 13 Mar 2024 10:09:10 +0100 Subject: [PATCH 2/3] Return instead of crash when trying to plot model without edges --- python/ribasim/ribasim/geometry/edge.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/ribasim/ribasim/geometry/edge.py b/python/ribasim/ribasim/geometry/edge.py index e905933c3..710503338 100644 --- a/python/ribasim/ribasim/geometry/edge.py +++ b/python/ribasim/ribasim/geometry/edge.py @@ -79,7 +79,8 @@ def get_where_edge_type(self, edge_type: str) -> NDArray[np.bool_]: return (self.df.edge_type == edge_type).to_numpy() def plot(self, **kwargs) -> Axes: - assert self.df is not None # Pleases mypy + if self.df is None: + return kwargs = kwargs.copy() # Avoid side-effects ax = kwargs.get("ax", None) color_flow = kwargs.pop("color_flow", None) From 3d354cdd3f3e68cdd71eeb56630f5fed8669c236 Mon Sep 17 00:00:00 2001 From: Hofer-Julian Date: Wed, 13 Mar 2024 10:12:21 +0100 Subject: [PATCH 3/3] Fix type hint --- python/ribasim/ribasim/geometry/edge.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/ribasim/ribasim/geometry/edge.py b/python/ribasim/ribasim/geometry/edge.py index 710503338..cfbd91b74 100644 --- a/python/ribasim/ribasim/geometry/edge.py +++ b/python/ribasim/ribasim/geometry/edge.py @@ -78,9 +78,9 @@ def get_where_edge_type(self, edge_type: str) -> NDArray[np.bool_]: assert self.df is not None return (self.df.edge_type == edge_type).to_numpy() - def plot(self, **kwargs) -> Axes: + def plot(self, **kwargs) -> Axes | None: if self.df is None: - return + return None kwargs = kwargs.copy() # Avoid side-effects ax = kwargs.get("ax", None) color_flow = kwargs.pop("color_flow", None)