Skip to content

Commit

Permalink
Support null dimension categories in simple.AdditiveAction
Browse files Browse the repository at this point in the history
  • Loading branch information
mechenich committed Feb 24, 2025
1 parent e78f2f2 commit d2a49a9
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 2 deletions.
1 change: 1 addition & 0 deletions nodes/actions/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class ActionNode(Node):
allowed_parameters: ClassVar[Sequence[Parameter]] = [
ENABLED_PARAM,
NumberParameter(local_id='multiplier', label='Multiplies the output', is_customizable=True),
BoolParameter(local_id='allow_null_categories', description='Allow null dimension categories', is_customizable=False),
]

def __init_subclass__(cls) -> None:
Expand Down
3 changes: 3 additions & 0 deletions nodes/actions/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ class AdditiveAction(ActionNode):
def compute_effect(self):
df = self.get_input_dataset_pl()

if self.get_parameter_value('allow_null_categories', required=False):
self.allow_null_categories = True

for m in self.output_metrics.values():
if not self.is_enabled():
df = df.with_columns(pl.when(pl.col(m.column_id).is_null()).then(None)
Expand Down
4 changes: 2 additions & 2 deletions nodes/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,7 +866,7 @@ def _get_output_for_target(self, df: ppl.PathsDataFrame, target_node: Node) -> p
raise NodeError(self, 'Dimension %s not in output df' % dim_id)
filter_expr = pl.col(dim_id).is_in(cat_ids)
if edge_dim.exclude:
filter_expr = ~filter_expr
filter_expr = pl.col(dim_id).is_null() | ~filter_expr
df = df.filter(filter_expr)
if len(df) == 0:
raise NodeError(self, 'No rows left after filtering by %s' % dim_id)
Expand Down Expand Up @@ -990,7 +990,7 @@ def validate_output(self, df: ppl.PathsDataFrame) -> None: # noqa: C901, PLR091

cats = set(df[dim_id].cast(pl.Utf8).unique())
if getattr(self, 'allow_null_categories', None):
cats -= {''}
cats -= {'', None}

dim_cats = dim.get_cat_ids()
diff = cats - dim_cats
Expand Down

0 comments on commit d2a49a9

Please sign in to comment.