Move multiply_nodes_pl() from GenericNode to Node
jtuomist committed Feb 25, 2025
1 parent db8d2d1 commit 1f58c83
Showing 5 changed files with 53 additions and 55 deletions.
12 changes: 7 additions & 5 deletions nodes/finland/hsy.py
@@ -369,11 +369,13 @@ def compute(self) -> ppl.PathsDataFrame:
assert VALUE_COLUMN in df
pdf = ppl.from_pandas(df)

pdf, extra_nodes = self.run_implicit_operations(other_nodes, pdf)
assert pdf is not None
assert len(extra_nodes) == 0, f"Node {self.id} should not have input nodes of the other type."
pdf = pdf.ensure_unit(VALUE_COLUMN, self.unit)
return pdf
dfout, extra_nodes = self.run_implicit_operations(pdf, other_nodes)
if dfout is None:
raise NodeError(self, f"Node {self.id} failed with implicit operations.")
if len(extra_nodes) > 0:
raise NodeError(self, f"Node {self.id} can only have additive and multiplicative input nodes.")
dfout = dfout.ensure_unit(VALUE_COLUMN, self.unit)
return dfout

class HsyEmissionFactor(AdditiveNode, HsyNodeMixin):
default_unit = 'g/kWh'
29 changes: 29 additions & 0 deletions nodes/node.py
@@ -1498,6 +1498,35 @@ def add_nodes_pl( # noqa: C901, PLR0912, PLR0915
df = df.select([YEAR_COLUMN, *meta.dim_ids, VALUE_COLUMN, FORECAST_COLUMN])
return df

def multiply_nodes_pl( # FIXME Make more like add_nodes_pl but allow no inputs
self,
df: ppl.PathsDataFrame | None,
nodes: list[Node],
metric: str | None = None,
keep_nodes: bool = False,
node_multipliers: list[float] | None = None,
unit: Unit | None = None,
start_from_year: int | None = None,
) -> ppl.PathsDataFrame | None:
"""Multiply outputs from the given nodes using inner join and union of dimensions."""
if not nodes and df is None:
return None

if df is not None:
result = df
else:
result = nodes.pop(0).get_output_pl(target_node=self)

for node in nodes:
dfn = node.get_output_pl(target_node=self)
result = result.paths.join_over_index(dfn, how='inner', index_from='union')
result = result.multiply_cols(
[VALUE_COLUMN, f'{VALUE_COLUMN}_right'],
VALUE_COLUMN
).drop(f'{VALUE_COLUMN}_right')

return result

def check(self):
from nodes.metric import DimensionalMetric

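For context, here is a minimal sketch (not part of this commit) of how a Node subclass could call the relocated multiply_nodes_pl(); ExampleProductNode is hypothetical, and ppl, NodeError and VALUE_COLUMN are assumed to be importable as elsewhere in the repository.

class ExampleProductNode(Node):  # hypothetical subclass for illustration only
    def compute(self) -> ppl.PathsDataFrame:
        # Pass a copy of input_nodes: multiply_nodes_pl() pops the first node
        # from the list when no starting data frame is given.
        df = self.multiply_nodes_pl(None, list(self.input_nodes))
        if df is None:
            raise NodeError(self, "Expected at least one input node to multiply.")
        return df.ensure_unit(VALUE_COLUMN, self.unit)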
62 changes: 13 additions & 49 deletions nodes/simple.py
@@ -10,7 +10,7 @@
from common import polars as ppl
from common.i18n import TranslatedString
from nodes.calc import convert_to_co2e, extend_last_historical_value_pl
from nodes.units import Quantity
from nodes.units import Quantity, Unit
from params.param import BoolParameter, NumberParameter, Parameter, StringParameter

from .constants import FORECAST_COLUMN, MIX_QUANTITY, VALUE_COLUMN, YEAR_COLUMN
Expand Down Expand Up @@ -227,54 +227,15 @@ def _get_categorized_inputs(self, nodes: list[Node]) -> tuple[list[Node], list[N

return multiplicative_nodes, additive_nodes, other_nodes

def _multiply_nodes(self, nodes: list[Node], base_df: ppl.PathsDataFrame | None = None) -> ppl.PathsDataFrame | None:
"""Multiply outputs from the given nodes using inner join and union of dimensions."""
if not nodes and base_df is None:
return None

if base_df is not None:
result = base_df
else:
result = nodes.pop(0).get_output_pl(target_node=self)

for node in nodes:
df = node.get_output_pl(target_node=self)
result = result.paths.join_over_index(df, how='inner', index_from='union')
result = result.multiply_cols(
[VALUE_COLUMN, f'{VALUE_COLUMN}_right'],
VALUE_COLUMN
).drop(f'{VALUE_COLUMN}_right')

return result

def _add_nodes(self, nodes: list[Node], base_df: ppl.PathsDataFrame | None = None) -> ppl.PathsDataFrame | None:
"""Add outputs from the given nodes using outer join."""
if not nodes and base_df is None:
return None

if base_df is not None:
result = base_df
else:
result = nodes.pop(0).get_output_pl(target_node=self)

for node in nodes:
df = node.get_output_pl(target_node=self)
if set(df.dim_ids) != set(result.dim_ids):
raise NodeError(
self,
f"Dimensions don't match for implicit addition: {df.dim_ids} vs {result.dim_ids}"
)

result = result.paths.add_with_dims(df, how='outer')
# TODO Should null handling be configurable by parameter?
result = result.with_columns(pl.col(VALUE_COLUMN).fill_null(0.0))

return result

def run_implicit_operations(
self,
df: ppl.PathsDataFrame | None = None,
nodes: list[Node] | None = None,
base_df: ppl.PathsDataFrame | None = None
metric: str | None = None,
keep_nodes: bool = False,
node_multipliers: list[float] | None = None,
unit: Unit | None = None,
start_from_year: int | None = None,
) -> tuple[ppl.PathsDataFrame | None, list[Node]]:
"""
Process all inputs according to their categories.
@@ -285,8 +246,11 @@ def run_implicit_operations(
nodes = self.input_nodes
mult_nodes, add_nodes, other_nodes = self._get_categorized_inputs(nodes)

result = self._multiply_nodes(mult_nodes, base_df)
result = self._add_nodes(add_nodes, result)

result = self.multiply_nodes_pl(df, mult_nodes, metric, keep_nodes, node_multipliers,
unit, start_from_year)
result = self.add_nodes_pl(result, add_nodes, metric, keep_nodes, node_multipliers,
unit, start_from_year)

return result, other_nodes

@@ -946,7 +910,7 @@ class FillNewCategoryNode(AdditiveNode):
]

def compute(self):
category = self.get_parameter_value('new_category', required=True, units=False)
category = self.get_parameter_value_str('new_category', required=True)
dim, cat = category.split(':')

df: ppl.PathsDataFrame = self.add_nodes_pl(None, self.input_nodes)
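As a usage note (not part of the diff): the call convention changes from run_implicit_operations(other_nodes, pdf) to run_implicit_operations(pdf, other_nodes, ...), with the extra keyword arguments forwarded to multiply_nodes_pl()/add_nodes_pl(). A hedged sketch under the assumption that run_implicit_operations() is defined on GenericNode (per the commit title); ExampleMixedNode is hypothetical.

class ExampleMixedNode(GenericNode):  # hypothetical subclass for illustration only
    def compute(self) -> ppl.PathsDataFrame:
        # New argument order: data frame (or None) first, then the node list.
        dfout, other_nodes = self.run_implicit_operations(None, self.input_nodes)
        if dfout is None:
            raise NodeError(self, "Expected additive or multiplicative input nodes.")
        if other_nodes:
            raise NodeError(self, "Input nodes of other types are not handled here.")
        return dfout.ensure_unit(VALUE_COLUMN, self.unit)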
3 changes: 3 additions & 0 deletions nodes/units.py
@@ -375,6 +375,9 @@ def set_one(u: str, long: str | StrPromise, short: str | StrPromise | None = Non
set_one('personal_activity', long=_('minutes per day per person'), short=_('min/d/cap'))
set_one('gigaEUR', long=_('billion euros'), short=_('B€'))
set_one('megasolid_cubic_meter', long='million solid m³', short='M m³ (solid)')
if 'kilowatt_hour' in _babel_units:
del _babel_units['kilowatt_hour'] # Otherwise fails with compound units.
set_one('kilowatt_hour', long=_('kilowatt hour'), short='kWh')

loc = Loc('de')
loc._data['unit_patterns']['duration-year']['short'] = dict(one='a')
