Skip to content

Commit

Permalink
Order output tables based on datamodel
Browse files Browse the repository at this point in the history
Fixes #395 and tests it as well
  • Loading branch information
e-lo committed Oct 18, 2024
1 parent bb8a71d commit 63a1963
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 8 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/push.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ jobs:
run: ruff check --output-format=github network_wrangler
- name: Run tests with coverage and benchmarking
run: |
pytest --junitxml=coverage.xml --benchmark-save=pr_benchmark --benchmark-json=pr_benchmark.json
pytest --junitxml=coverage.xml --benchmark-save=pr_benchmark --benchmark-json=pr_benchmark.json
- name: Store benchmark result
uses: benchmark-action/github-action-benchmark@v1
with:
Expand Down
5 changes: 2 additions & 3 deletions network_wrangler/roadway/links/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ...models.roadway.tables import RoadLinksAttrs, RoadLinksTable, RoadNodesAttrs, RoadNodesTable
from ...params import LAT_LON_CRS
from ...utils.io_table import read_table, write_table
from ...utils.models import validate_call_pyd
from ...utils.models import order_fields_from_data_model, validate_call_pyd
from .create import data_to_links_df


Expand Down Expand Up @@ -90,8 +90,6 @@ def write_links(
overwrite: if True, will overwrite existing files. Defaults to False.
include_geometry: if True, will include geometry in the output. Defaults to False.
"""
# TODO write wrapper on validate call so don't have to do this
links_df.attrs.update(RoadLinksAttrs)
if not include_geometry and file_format == "geojson":
file_format = "json"

Expand All @@ -112,4 +110,5 @@ def write_links(
links_df = pd.DataFrame(links_df)
links_df = links_df.drop(columns=geo_cols)

links_df = order_fields_from_data_model(links_df, RoadLinksTable)
write_table(links_df, links_file, overwrite=overwrite)
5 changes: 2 additions & 3 deletions network_wrangler/roadway/nodes/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ...models.roadway.tables import RoadNodesAttrs, RoadNodesTable
from ...params import LAT_LON_CRS
from ...utils.io_table import read_table, write_table
from ...utils.models import validate_call_pyd, validate_df_to_model
from ...utils.models import order_fields_from_data_model, validate_call_pyd, validate_df_to_model
from .create import data_to_nodes_df

if TYPE_CHECKING:
Expand Down Expand Up @@ -115,9 +115,8 @@ def write_nodes(
to "geojson".
overwrite: whether to overwrite existing nodes file. Defaults to True.
"""
# TODO write wrapper on validate call so don't have to do this
nodes_df.attrs.update(RoadNodesAttrs)
nodes_file = Path(out_dir) / f"{prefix}node.{file_format}"
nodes_df = order_fields_from_data_model(nodes_df, RoadNodesTable)
write_table(nodes_df, nodes_file, overwrite=overwrite)


Expand Down
8 changes: 7 additions & 1 deletion network_wrangler/roadway/shapes/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@
from ...models.roadway.tables import RoadShapesTable
from ...params import LAT_LON_CRS
from ...utils.io_table import read_table, write_table
from ...utils.models import empty_df_from_datamodel, validate_call_pyd, validate_df_to_model
from ...utils.models import (
empty_df_from_datamodel,
order_fields_from_data_model,
validate_call_pyd,
validate_df_to_model,
)
from .create import df_to_shapes_df


Expand Down Expand Up @@ -98,4 +103,5 @@ def write_shapes(
overwrite: whether to overwrite file if it exists.
"""
shapes_file = Path(out_dir) / f"{prefix}shape.{format}"
shapes_df = order_fields_from_data_model(shapes_df, RoadShapesTable)
write_table(shapes_df, shapes_file, overwrite=overwrite)
16 changes: 16 additions & 0 deletions network_wrangler/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,3 +241,19 @@ def wrapper(*args, **kwargs):
return validated_func(*args, **kwargs)

return wrapper


def order_fields_from_data_model(df: pd.DataFrame, model: DataFrameModel) -> pd.DataFrame:
    """Reorder a DataFrame's columns to follow the field order of a Pandera DataFrameModel.

    Columns that are named in the model come first, in the order the model
    declares them. Columns that the model does not name follow, keeping the
    relative order they already have in the DataFrame. Model fields that are
    missing from the DataFrame are skipped; no columns are created.

    Args:
        df: DataFrame whose columns should be reordered.
        model: Pandera DataFrameModel whose ``__fields__`` supplies the order.

    Returns:
        The selection of ``df`` by the reordered list of column names.
    """
    declared_order = list(model.__fields__)
    declared_names = set(declared_order)

    # Model-declared columns that actually exist in the frame, in model order.
    leading_cols = [name for name in declared_order if name in df.columns]
    # Everything the model does not know about, in the frame's existing order.
    trailing_cols = [name for name in df.columns if name not in declared_names]

    return df[leading_cols + trailing_cols]
10 changes: 10 additions & 0 deletions tests/test_roadway/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
load_roadway_from_dir,
write_roadway,
)
from network_wrangler.models.roadway.tables import RoadLinksTable
from network_wrangler.roadway import diff_nets
from network_wrangler.roadway.io import (
convert_roadway_file_serialization,
Expand Down Expand Up @@ -156,6 +157,15 @@ def test_roadway_geojson_read_write_read(request, example_dir, test_out_dir, ex,
f"{int(t_read // 60): 02d}:{int(t_read % 60): 02d} ... {ex} read from {io_format}"
)
assert isinstance(net, RoadwayNetwork)
# make sure field order is as expected.
skip_ordered = ["geometry"]
_shared_ordered_fields = [
c for c in RoadLinksTable.__fields__ if c in net.links_df.columns and c not in skip_ordered
]
_output_cols = [c for c in net.links_df.columns if c not in skip_ordered][
0 : len(_shared_ordered_fields)
]
assert _output_cols == _shared_ordered_fields


def test_load_roadway_no_shapes(request, example_dir):
Expand Down

0 comments on commit 63a1963

Please sign in to comment.