Enhance performance of selection by dictionary.
Speeds up explicit ID link selection for the St Paul network from 0.9s to 0.015s
e-lo committed Oct 15, 2024
1 parent bcdc41c commit c630b00
Showing 13 changed files with 110 additions and 34 deletions.
4 changes: 4 additions & 0 deletions network_wrangler/errors.py
@@ -117,6 +117,10 @@ class SelectionError(Exception):
"""Raised when there is an issue with a selection."""


class DataframeSelectionError(Exception):
"""Raised when there is an issue with a selection from a dataframe."""


class ShapeAddError(Exception):
"""Raised when there is an issue with adding shapes."""

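The new DataframeSelectionError is raised by the dataframe selection helper added later in this diff (network_wrangler/utils/data.py). A minimal sketch of catching it, assuming the package is installed; the toy dataframe and the bad key are illustrative only:

```python
import pandas as pd

from network_wrangler.errors import DataframeSelectionError
from network_wrangler.utils.data import isin_dict

df = pd.DataFrame({"model_link_id": [1, 2, 3]})
try:
    isin_dict(df, {"col3": [1]})  # "col3" is not a dataframe column, so the helper raises
except DataframeSelectionError as err:
    print(f"Selection failed: {err}")
```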
2 changes: 1 addition & 1 deletion network_wrangler/roadway/links/create.py
@@ -92,7 +92,7 @@ def data_to_links_df(
WranglerLogger.debug(f"Creating {len(links_df)} links.")
if not isinstance(links_df, pd.DataFrame):
links_df = pd.DataFrame(links_df)
WranglerLogger.debug(f"data_to_links_df.links_df input: \n{links_df.head}.")
# WranglerLogger.debug(f"data_to_links_df.links_df input: \n{links_df.head}.")

v0_link_properties = detect_v0_scoped_link_properties(links_df)
if v0_link_properties:
2 changes: 1 addition & 1 deletion network_wrangler/roadway/links/filters.py
@@ -163,7 +163,7 @@ def filter_links_to_ids(


def filter_links_not_in_ids(
links_df: DataFrame[RoadLinksTable], link_ids: list[int]
links_df: DataFrame[RoadLinksTable], link_ids: Union[list[int], pd.Series]
) -> DataFrame[RoadLinksTable]:
"""Filters links dataframe to NOT have link_ids."""
return links_df.loc[~links_df["model_link_id"].isin(link_ids)]
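The widened signature above just reflects that callers can now pass the pandas Series produced by the id helpers instead of a Python list; Series.isin accepts either and returns the same result. A toy sketch (column name follows the wrangler link schema, data illustrative):

```python
import pandas as pd

links_df = pd.DataFrame({"model_link_id": [1, 2, 3, 4]})
exclude_list = [2, 4]
exclude_series = pd.Series(exclude_list)

# Filtering against a list and against a Series gives identical results.
as_list = links_df.loc[~links_df["model_link_id"].isin(exclude_list)]
as_series = links_df.loc[~links_df["model_link_id"].isin(exclude_series)]
print(as_list.equals(as_series))  # True
```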
8 changes: 4 additions & 4 deletions network_wrangler/roadway/links/links.py
@@ -15,7 +15,7 @@

def node_ids_in_links(
links_df: DataFrame[RoadLinksTable], nodes_df: Optional[DataFrame[RoadNodesTable]] = None
) -> list[int]:
) -> pd.Series:
"""Returns the unique node_ids in a links dataframe.
Args:
@@ -26,7 +26,7 @@ def node_ids_in_links(
Returns:
pd.Series: unique node_ids
"""
_node_ids = list(set(links_df["A"]).union(set(links_df["B"])))
_node_ids = pd.concat([links_df["A"], links_df["B"]]).unique()

if nodes_df is not None:
validate_links_have_nodes(links_df, nodes_df)
@@ -37,7 +37,7 @@ def node_ids_in_link_ids(
link_ids: list[int],
links_df: DataFrame[RoadLinksTable],
nodes_df: Optional[DataFrame[RoadNodesTable]] = None,
) -> list[int]:
) -> pd.Series:
"""Returns the unique node_ids in a list of link_ids.
Args:
@@ -61,7 +61,7 @@ def node_ids_unique_to_link_ids(
_unselected_links_df = filter_links_not_in_ids(links_df, link_ids)
unselected_link_node_ids = node_ids_in_links(_unselected_links_df, nodes_df=nodes_df)

return list(set(selected_link_node_ids) - set(unselected_link_node_ids))
return selected_link_node_ids[~selected_link_node_ids.isin(unselected_link_node_ids)].tolist()


def shape_ids_in_links(
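The pattern in node_ids_in_links and node_ids_unique_to_link_ids swaps Python set operations for vectorized pandas ones: concatenate the A/B columns and deduplicate with .unique(), then use .isin() for the set difference. A toy sketch of the same pattern (data and column contents illustrative; .unique() returns a numpy array, wrapped in a Series here so .isin is available):

```python
import pandas as pd

# Unique node ids touched by a set of links: concat + unique instead of set().union().
selected_links = pd.DataFrame({"A": [1, 2, 3], "B": [2, 3, 4]})
selected_node_ids = pd.Series(pd.concat([selected_links["A"], selected_links["B"]]).unique())

# Node ids touched by the remaining links.
other_links = pd.DataFrame({"A": [3], "B": [4]})
other_node_ids = pd.Series(pd.concat([other_links["A"], other_links["B"]]).unique())

# Ids unique to the selected links: isin-based difference instead of set subtraction.
print(selected_node_ids[~selected_node_ids.isin(other_node_ids)].tolist())  # [1, 2]
```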
6 changes: 3 additions & 3 deletions network_wrangler/roadway/model_roadway.py
@@ -249,7 +249,7 @@ def _generate_ml_node_id_from_range(nodes_df, links_df, node_id_range: tuple[int
available for provided range: {node_id_range}."
raise ValueError(msg)
new_ml_node_ids = list(avail_ml_node_ids)[: len(og_ml_node_ids)]
return dict(zip(og_ml_node_ids, new_ml_node_ids))
return dict(zip(og_ml_node_ids.tolist(), new_ml_node_ids))


def _generate_ml_link_id_lookup_from_scalar(links_df: DataFrame[RoadLinksTable], scalar: int):
@@ -265,11 +265,11 @@ def _generate_ml_link_id_lookup_from_scalar(links_df: DataFrame[RoadLinksTable],
def _generate_ml_node_id_lookup_from_scalar(nodes_df, links_df, scalar: int):
"""Generate a lookup for managed lane node ids to their general purpose lane counterparts."""
og_ml_node_ids = node_ids_in_links(links_df.of_type.managed, nodes_df)
node_id_list = [i + scalar for i in og_ml_node_ids]
node_id_list = og_ml_node_ids + scalar
if nodes_df.model_node_id.isin(node_id_list).any():
msg = f"New node ids generated by scalar {scalar} already exist. Try a different scalar."
raise ValueError(msg)
return dict(zip(og_ml_node_ids, node_id_list))
return dict(zip(og_ml_node_ids.tolist(), node_id_list.tolist()))


def model_links_nodes_from_net(
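The scalar-based managed-lane node id lookup now offsets the whole id vector at once instead of looping, then converts both sides to plain Python ints before building the dict. A small sketch, assuming og_ml_node_ids behaves like a numeric pandas Series or numpy array (the ids and scalar below are illustrative):

```python
import pandas as pd

og_ml_node_ids = pd.Series([10, 11, 12])   # illustrative managed-lane node ids
scalar = 1000

new_ids = og_ml_node_ids + scalar          # vectorized: no Python-level loop
lookup = dict(zip(og_ml_node_ids.tolist(), new_ids.tolist()))
print(lookup)                              # {10: 1010, 11: 1011, 12: 1012}
```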
4 changes: 3 additions & 1 deletion network_wrangler/roadway/network.py
@@ -658,7 +658,9 @@ def has_link(self, ab: tuple) -> bool:
ab: Tuple of values corresponding with A and B.
"""
sel_a, sel_b = ab
has_link = self.links_df[self.links_df[["A", "B"]]].isin({"A": sel_a, "B": sel_b}).any()
has_link = (
self.links_df[self.links_df[["A", "B"]]].isin_dict({"A": sel_a, "B": sel_b}).any()
)
return has_link

def is_connected(self, mode: str) -> bool:
4 changes: 1 addition & 3 deletions network_wrangler/roadway/segment.py
@@ -201,9 +201,7 @@ def get_node_id(self, node_selection_data: SelectNodeDict) -> int:

def get_node(self, node_selection_data: SelectNodeDict):
"""Get single node based on the selection data."""
sel_d = node_selection_data.explicit_id_selection_dict
_sel_node_mask = self.net.nodes_df.isin(sel_d).any(axis=1)
node_df = self.net.nodes_df.loc[_sel_node_mask]
node_df = self.net.nodes_df.isin_dict(node_selection_data.explicit_id_selection_dict)
if len(node_df) != 1:
msg = f"Node selection not unique. Found {len(node_df)} nodes."
raise SegmentSelectionError(msg)
25 changes: 7 additions & 18 deletions network_wrangler/roadway/selection.py
@@ -223,10 +223,9 @@ def validate_selection(self, sel_data: SelectFacility) -> SelectFacility:

def _perform_selection(self):
# 1. Initial selection based on selection type
WranglerLogger.debug(
f"Initial link selection type: \
{self.feature_types}.{self.selection_data.selection_type}"
)
msg = f"Initial link selection type: {self.feature_types}.{self.selection_data.selection_type}"
WranglerLogger.debug(msg)

if self.selection_type == "explicit_ids":
_selected_links_df = self._select_explicit_link_id()

@@ -267,20 +266,10 @@ def _perform_selection(self):
def _select_explicit_link_id(self):
"""Select links based on a explicit link id in selection_dict."""
WranglerLogger.info("Selecting using explicit link identifiers.")
WranglerLogger.debug(f"Explicit link selection dictionary: {self.explicit_id_sel_dict}")
missing_values = {
col: list(set(values) - set(self.net.links_df[col]))
for col, values in self.explicit_id_sel_dict.items()
}
missing_df = pd.DataFrame(missing_values)
if len(missing_df) > 0:
WranglerLogger.warning(f"Missing explicit link selections: \n{missing_df}")
if not self.ignore_missing:
msg = "Missing explicit link selections."
raise SelectionError(msg)

_sel_links_mask = self.net.links_df.isin(self.explicit_id_sel_dict).any(axis=1)
_sel_links_df = self.net.links_df.loc[_sel_links_mask]
# WranglerLogger.debug(f"Explicit link selection dictionary: {self.explicit_id_sel_dict}")
_sel_links_df = self.net.links_df.isin_dict(
self.explicit_id_sel_dict, ignore_missing=self.ignore_missing
)

return _sel_links_df

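For reference, the removed branch built the mask with DataFrame.isin plus a per-column set difference for the missing-value check; isin_dict keeps the same "a row matches if ANY listed column matches" semantics while doing both steps in one pass. A toy replay of those semantics with plain pandas (column names follow the wrangler link schema, values illustrative):

```python
import pandas as pd

links_df = pd.DataFrame({"model_link_id": [1, 2, 3, 4], "osm_link_id": ["a", "b", "c", "d"]})
explicit_id_sel_dict = {"model_link_id": [2], "osm_link_id": ["d"]}

# Pre-commit approach: keep rows where any listed column matches a requested value.
mask = links_df.isin(explicit_id_sel_dict).any(axis=1)
print(links_df.loc[mask, "model_link_id"].tolist())  # [2, 4]
```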
2 changes: 1 addition & 1 deletion network_wrangler/roadway/subnet.py
@@ -153,7 +153,7 @@ def num_links(self):
return len(self.subnet_links_df)

@property
def subnet_nodes(self) -> list[int]:
def subnet_nodes(self) -> pd.Series:
"""List of node_ids in the subnet."""
if self.subnet_links_df is None:
msg = "Must set self.subnet_links_df before accessing subnet_nodes."
2 changes: 1 addition & 1 deletion network_wrangler/transit/feed/shapes.py
@@ -231,7 +231,7 @@ def shapes_for_road_links(
]
WranglerLogger.debug(
f"DEBUG AB: \n\
{shape_links_w_links[shape_links_w_links[['A', 'B']].isin(_debug_AB).all(axis=1)]}"
{shape_links_w_links[['A', 'B']].isin_dict(_debug_AB)}"
)

"""
30 changes: 30 additions & 0 deletions network_wrangler/utils/data.py
@@ -10,6 +10,7 @@
from numpy import ndarray
from shapely import wkt

from ..errors import DataframeSelectionError
from ..logger import WranglerLogger
from ..params import LAT_LON_CRS

@@ -703,3 +704,32 @@ def concat_with_attr(dfs: list[pd.DataFrame], **kwargs) -> pd.DataFrame:
df = pd.concat(dfs, **kwargs)
df.attrs = attrs
return df


def isin_dict(df: pd.DataFrame, d: dict, ignore_missing: bool = True) -> pd.DataFrame:
"""Filter the dataframe using a dictionary - faster than using isin.
Uses merge to filter the dataframe by the dictionary keys and values.
"""
sel_links_mask = np.zeros(len(df), dtype=bool)
missing = {}
for col, vals in d.items():
if vals is None:
continue
if col not in df.columns:
msg = f"Key {col} not in dataframe columns."
raise DataframeSelectionError(msg)
vals_s = pd.DataFrame({col: vals})
index_name = df.index.name if df.index.name is not None else "index"
_df = df[[col]].reset_index(names=index_name)
merged_df = _df.merge(vals_s, on=col, how="outer", indicator=True)
selected = merged_df[merged_df["_merge"] == "both"].set_index(index_name)
sel_links_mask |= df.index.isin(selected.index)
missing[col] = merged_df.loc[merged_df["_merge"] == "right_only", col].tolist()
if len(missing[col]):
            WranglerLogger.warning(f"Missing values in selection dict for {col}: {missing[col]}")
    if not ignore_missing and any(missing.values()):
msg = "Missing values in selection dict."
raise DataframeSelectionError(msg)

return df.loc[sel_links_mask]
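The merge with indicator=True is what makes this both the filter and the missing-value report: requested values that match land in the "both" rows (used to build the boolean mask), while "right_only" rows are exactly the requested values absent from the dataframe. A usage sketch on toy data (the import path comes from this diff; the column names and values are illustrative):

```python
import pandas as pd

from network_wrangler.utils.data import isin_dict

df = pd.DataFrame({"model_link_id": [1, 2, 3], "name": ["main", "oak", "elm"]})

print(isin_dict(df, {"model_link_id": [2, 3]}))   # keeps the rows with ids 2 and 3
print(isin_dict(df, {"model_link_id": [2, 99]}))  # keeps the row with id 2; warns that 99 is missing
# isin_dict(df, {"model_link_id": [99]}, ignore_missing=False)  # raises DataframeSelectionError
```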
15 changes: 14 additions & 1 deletion network_wrangler/utils/df_accessors.py
@@ -5,7 +5,7 @@
import pandas as pd

from ..logger import WranglerLogger
from .data import dict_to_query
from .data import dict_to_query, isin_dict


@pd.api.extensions.register_dataframe_accessor("dict_query")
@@ -77,3 +77,16 @@ def __call__(self):
_value = str(self._obj.values).encode()
hash = hashlib.sha1(_value).hexdigest()
return hash


@pd.api.extensions.register_dataframe_accessor("isin_dict")
class Isin_dict:
"""Faster implimentation of isin for querying dataframes with dictionary."""

def __init__(self, pandas_obj):
"""Initialization function for the dataframe hash."""
self._obj = pandas_obj

def __call__(self, d: dict, **kwargs) -> pd.DataFrame:
"""Function to perform the faster dictionary isin()."""
return isin_dict(self._obj, d, **kwargs)
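Registering the class with pd.api.extensions.register_dataframe_accessor is what makes df.isin_dict(...) available on any DataFrame once this module has been imported. A usage sketch, assuming importing the module (or the package, if it imports the accessors) performs the registration; the toy dataframe is illustrative:

```python
import pandas as pd

import network_wrangler.utils.df_accessors  # noqa: F401  (importing registers the accessor)

df = pd.DataFrame({"model_link_id": [1, 2, 3]})
print(df.isin_dict({"model_link_id": [1, 3]}))  # same rows as utils.data.isin_dict(df, ...)
```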
40 changes: 40 additions & 0 deletions tests/test_utils/test_data.py
@@ -8,13 +8,15 @@
import pytest
from pandas import testing as tm

from network_wrangler.errors import DataframeSelectionError
from network_wrangler.logger import WranglerLogger
from network_wrangler.utils.data import (
DataSegmentationError,
InvalidJoinFieldError,
MissingPropertiesError,
dict_to_query,
diff_dfs,
isin_dict,
list_like_columns,
segment_data_by_selection,
segment_data_by_selection_min_overlap,
@@ -568,3 +570,41 @@ def test_update_props_from_one_to_many():
)
# Check if the updated_df matches the expected_df
pd.testing.assert_frame_equal(updated_df, expected_df)


def test_isin_dict_basic():
df = pd.DataFrame({"col1": [1, 2, 3, 4, 5], "col2": ["a", "b", "c", "d", "e"]})
d = {"col1": [2, 4], "col2": ["c", "d"]}
expected_df = pd.DataFrame({"col1": [2, 3, 4], "col2": ["b", "c", "d"]})
result_df = isin_dict(df, d)
pd.testing.assert_frame_equal(result_df.reset_index(drop=True), expected_df)


def test_isin_dict_missing_values():
df = pd.DataFrame({"col1": [1, 2, 3, 4, 5], "col2": ["a", "b", "c", "d", "e"]})
d = {"col1": [2, 6], "col2": ["b", "e"]}
expected_df = pd.DataFrame({"col1": [2, 5], "col2": ["b", "e"]})
result_df = isin_dict(df, d)
pd.testing.assert_frame_equal(result_df.reset_index(drop=True), expected_df)


def test_isin_dict_ignore_missing_false():
df = pd.DataFrame({"col1": [1, 2, 3, 4, 5], "col2": ["a", "b", "c", "d", "e"]})
d = {"col1": [2, 6], "col2": ["b", "f"]}
with pytest.raises(DataframeSelectionError):
isin_dict(df, d, ignore_missing=False)


def test_isin_dict_non_existing_column():
df = pd.DataFrame({"col1": [1, 2, 3, 4, 5], "col2": ["a", "b", "c", "d", "e"]})
d = {"col1": [2, 4], "col3": ["x", "y"]}
with pytest.raises(DataframeSelectionError):
isin_dict(df, d)


def test_isin_dict_empty_dataframe():
df = pd.DataFrame(columns=["col1", "col2"])
d = {"col1": [2, 4], "col2": ["b", "d"]}
expected_df = pd.DataFrame(columns=["col1", "col2"])
result_df = isin_dict(df, d)
pd.testing.assert_frame_equal(result_df, expected_df)
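Outside the test suite, a rough micro-benchmark along these lines is the kind of comparison behind the commit message's 0.9s-to-0.015s claim. Treat it purely as a sketch: the dataframe size, selection dictionary, and absolute timings below are made up, and real numbers depend on the network and machine.

```python
import timeit

import numpy as np
import pandas as pd

from network_wrangler.utils.data import isin_dict

# Illustrative data: a wide id range with a sparse selection dictionary.
df = pd.DataFrame({"model_link_id": np.arange(500_000), "name": "x"})
d = {"model_link_id": list(range(0, 500_000, 50))}

t_isin = timeit.timeit(lambda: df.loc[df.isin(d).any(axis=1)], number=5)
t_dict = timeit.timeit(lambda: isin_dict(df, d), number=5)
print(f"DataFrame.isin: {t_isin:.3f}s   isin_dict: {t_dict:.3f}s")
```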
