Add tag-based column selection to ColumnSelector and Schema.

Review changes relative to the originally submitted patch:
  * Schema.apply_inverse now also excludes columns matched by selector.tags,
    mirroring Schema.apply (a regression test is included).
  * ColumnSelector.tags drops a redundant .keys() call, and a stray blank line
    in ColumnSelector.__add__ is removed.
The post-image blob ids on the "index" lines of the three edited files are
carried over from the original patch and no longer match the edited content.

diff --git a/nvtabular/columns/schema.py b/nvtabular/columns/schema.py
index 9534d17f77f..ff4a3a29fc4 100644
--- a/nvtabular/columns/schema.py
+++ b/nvtabular/columns/schema.py
@@ -213,16 +213,23 @@ def column_names(self):
         return list(self.column_schemas.keys())
 
     def apply(self, selector):
-        if selector and selector.names:
-            return self.select_by_name(selector.names)
-        else:
-            return self
+        if selector:
+            schema = Schema()
+            if selector.names:
+                schema += self.select_by_name(selector.names)
+            if selector.tags:
+                schema += self.select_by_tag(selector.tags)
+            return schema
+        return self
 
     def apply_inverse(self, selector):
         if selector:
-            return self - self.select_by_name(selector.names)
-        else:
-            return self
+            # Mirror apply(): drop columns matched by name, then those matched by tag
+            schema = self - self.select_by_name(selector.names)
+            if selector.tags:
+                schema = schema - schema.select_by_tag(selector.tags)
+            return schema
+        return self
 
     def select_by_tag(self, tags):
         if not isinstance(tags, list):
@@ -231,7 +238,7 @@ def select_by_tag(self, tags):
         selected_schemas = {}
 
         for _, column_schema in self.column_schemas.items():
-            if all(x in column_schema.tags for x in tags):
+            if any(x in column_schema.tags for x in tags):
                 selected_schemas[column_schema.name] = column_schema
 
         return Schema(selected_schemas)
diff --git a/nvtabular/columns/selector.py b/nvtabular/columns/selector.py
index ca34790ba68..d7e4786ff89 100644
--- a/nvtabular/columns/selector.py
+++ b/nvtabular/columns/selector.py
@@ -13,9 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import List
+from typing import List, Union
 
 import nvtabular
+from nvtabular.tags import Tags
 
 
 class ColumnSelector:
@@ -35,9 +36,16 @@ class ColumnSelector:
     of nesting tuples inside the list of names)
     """
 
-    def __init__(self, names: List[str] = None, subgroups: List["ColumnSelector"] = None):
+    def __init__(
+        self,
+        names: List[str] = None,
+        subgroups: List["ColumnSelector"] = None,
+        tags: List[Union[Tags, str]] = None,
+    ):
         self._names = names if names is not None else []
-        self.subgroups = subgroups if subgroups else []
+        self._tags = tags if tags is not None else []
+        self.subgroups = subgroups if subgroups is not None else []
+
         if isinstance(self._names, nvtabular.WorkflowNode):
             raise TypeError("ColumnSelectors can not contain WorkflowNodes")
 
@@ -60,6 +68,10 @@ def __init__(self, names: List[str] = None, subgroups: List["ColumnSelector"] =
         self._names = plain_names
         self._nested_check()
 
+    @property
+    def tags(self):
+        return list(dict.fromkeys(self._tags))
+
     @property
     def names(self):
         names = []
@@ -92,7 +104,13 @@ def __add__(self, other):
         elif isinstance(other, nvtabular.WorkflowNode):
             return other + self
         elif isinstance(other, ColumnSelector):
-            return ColumnSelector(self._names + other._names, self.subgroups + other.subgroups)
+            return ColumnSelector(
+                self._names + other._names,
+                self.subgroups + other.subgroups,
+                tags=self._tags + other._tags,
+            )
+        elif isinstance(other, Tags):
+            return ColumnSelector(self._names, self.subgroups, tags=self._tags + [other])
         else:
             if isinstance(other, str):
                 other = [other]
diff --git a/nvtabular/ops/fill.py b/nvtabular/ops/fill.py
index b5350062919..914fa1ff48c 100644
--- a/nvtabular/ops/fill.py
+++ b/nvtabular/ops/fill.py
@@ -71,6 +71,13 @@ def inference_initialize(self, col_selector, inference_config):
     def compute_output_schema(self, input_schema: Schema, col_selector: ColumnSelector) -> Schema:
         if not col_selector:
             col_selector = ColumnSelector(input_schema.column_names)
+        if col_selector.tags:
+            tags_col_selector = ColumnSelector(tags=col_selector.tags)
+            filtered_schema = input_schema.apply(tags_col_selector)
+            col_selector += ColumnSelector(filtered_schema.column_names)
+
+            # zero tags because already filtered
+            col_selector._tags = []
         output_schema = Schema()
         for column_name in col_selector.names:
             column_schema = input_schema.column_schemas[column_name]
diff --git a/nvtabular/ops/groupby.py b/nvtabular/ops/groupby.py
index 4e9c1ed8c8b..b467d61e954 100644
--- a/nvtabular/ops/groupby.py
+++ b/nvtabular/ops/groupby.py
@@ -141,6 +141,13 @@ def _dtypes(self):
         return numpy.int64
 
     def compute_output_schema(self, input_schema: Schema, col_selector: ColumnSelector) -> Schema:
+        if col_selector.tags:
+            tags_col_selector = ColumnSelector(tags=col_selector.tags)
+            filtered_schema = input_schema.apply(tags_col_selector)
+            col_selector += ColumnSelector(filtered_schema.column_names)
+
+            # zero tags because already filtered
+            col_selector._tags = []
         new_col_selector = self.output_column_names(col_selector)
         new_list = []
         for name in col_selector.names:
diff --git a/nvtabular/ops/operator.py b/nvtabular/ops/operator.py
index 20d9d748e61..b88add3fd1e 100644
--- a/nvtabular/ops/operator.py
+++ b/nvtabular/ops/operator.py
@@ -74,6 +74,15 @@ def compute_output_schema(self, input_schema: Schema, col_selector: ColumnSelect
         """
         if not col_selector:
             col_selector = ColumnSelector(input_schema.column_names)
+
+        if col_selector.tags:
+            tags_col_selector = ColumnSelector(tags=col_selector.tags)
+            filtered_schema = input_schema.apply(tags_col_selector)
+            col_selector += ColumnSelector(filtered_schema.column_names)
+
+            # zero tags because already filtered
+            col_selector._tags = []
+
         col_selector = self.output_column_names(col_selector)
 
         for column_name in col_selector.names:
diff --git a/nvtabular/ops/rename.py b/nvtabular/ops/rename.py
index e2bdaec00a6..bdf1091815e 100644
--- a/nvtabular/ops/rename.py
+++ b/nvtabular/ops/rename.py
@@ -60,6 +60,13 @@ def transform(self, col_selector: ColumnSelector, df: DataFrameType) -> DataFram
     def compute_output_schema(self, input_schema: Schema, col_selector: ColumnSelector) -> Schema:
         if not col_selector:
             col_selector = ColumnSelector(input_schema.column_names)
+        if col_selector.tags:
+            tags_col_selector = ColumnSelector(tags=col_selector.tags)
+            filtered_schema = input_schema.apply(tags_col_selector)
+            col_selector += ColumnSelector(filtered_schema.column_names)
+
+            # zero tags because already filtered
+            col_selector._tags = []
         output_schema = Schema()
         for column_name in input_schema.column_schemas:
             new_names = self.output_column_names(ColumnSelector(column_name))
diff --git a/nvtabular/workflow/node.py b/nvtabular/workflow/node.py
index 38e5f919f77..8f1c9ee9dc6 100644
--- a/nvtabular/workflow/node.py
+++ b/nvtabular/workflow/node.py
@@ -75,7 +75,7 @@ def compute_schemas(self, root_schema):
                 len(self.parents) == 1
                 and isinstance(self.parents[0].op, internal.ConcatColumns)
                 and self.parents[0].selector
-                and self.parents[0].selector.names
+                and (self.parents[0].selector.names)
             ):
                 self.selector = self.parents[0].selector
 
diff --git a/tests/unit/columns/test_column_schemas.py b/tests/unit/columns/test_column_schemas.py
index 3716c260611..015c40c97e8 100644
--- a/tests/unit/columns/test_column_schemas.py
+++ b/tests/unit/columns/test_column_schemas.py
@@ -160,7 +160,7 @@ def test_dataset_schema_column_names():
     assert ds_schema.column_names == ["x", "y", "z"]
 
 
-def test_applying_selector_to_schema_selects_relevant_columns():
+def test_applying_selector_to_schema_selects_by_name():
     schema = Schema(["a", "b", "c", "d", "e"])
     selector = ColumnSelector(["a", "b"])
     result = schema.apply(selector)
@@ -173,6 +173,40 @@ def test_applying_selector_to_schema_selects_relevant_columns():
     assert result == schema
 
 
+def test_applying_selector_to_schema_selects_by_tags():
+    schema1 = ColumnSchema("col1", tags=["a", "b", "c"])
+    schema2 = ColumnSchema("col2", tags=["b", "c", "d"])
+
+    schema = Schema([schema1, schema2])
+    selector = ColumnSelector(tags=["a", "b"])
+    result = schema.apply(selector)
+
+    assert result.column_names == schema.column_names
+
+
+def test_applying_selector_to_schema_selects_by_name_or_tags():
+    schema1 = ColumnSchema("col1")
+    schema2 = ColumnSchema("col2", tags=["b", "c", "d"])
+
+    schema = Schema([schema1, schema2])
+    selector = ColumnSelector(["col1"], tags=["a", "b"])
+    result = schema.apply(selector)
+
+    assert result.column_names == schema.column_names
+
+
+def test_applying_inverse_selector_to_schema_excludes_by_name_or_tags():
+    schema1 = ColumnSchema("col1")
+    schema2 = ColumnSchema("col2", tags=["b", "c", "d"])
+    schema3 = ColumnSchema("col3", tags=["e"])
+
+    schema = Schema([schema1, schema2, schema3])
+    selector = ColumnSelector(["col1"], tags=["a", "b"])
+    result = schema.apply_inverse(selector)
+
+    assert result.column_names == ["col3"]
+
+
 def test_applying_inverse_selector_to_schema_selects_relevant_columns():
     schema = Schema(["a", "b", "c", "d", "e"])
     selector = ColumnSelector(["a", "b"])
diff --git a/tests/unit/columns/test_column_selector.py b/tests/unit/columns/test_column_selector.py
index 3bb603e787b..05ed4d40c46 100644
--- a/tests/unit/columns/test_column_selector.py
+++ b/tests/unit/columns/test_column_selector.py
@@ -17,6 +17,7 @@
 
 from nvtabular.columns import ColumnSelector
 from nvtabular.ops import Operator
+from nvtabular.tags import Tags
 from nvtabular.workflow import WorkflowNode
 
 
@@ -147,3 +148,45 @@ def test_rshift_operator_onto_selector_creates_node_with_selector():
     assert isinstance(output_node, WorkflowNode)
     assert output_node.selector == selector
     assert output_node.parents == []
+
+
+def test_construct_column_selector_with_tags():
+    target_tags = [Tags.CATEGORICAL, "custom"]
+    selector = ColumnSelector(tags=target_tags)
+    assert selector.tags == target_tags
+
+
+def test_returned_tags_are_unique():
+    selector = ColumnSelector(tags=["a", "b", "a"])
+    assert selector.tags == ["a", "b"]
+
+
+def test_addition_combines_tags():
+    selector1 = ColumnSelector(tags=["a", "b", "c"])
+    selector2 = ColumnSelector(tags=["g", "h", "i"])
+    combined = selector1 + selector2
+
+    assert combined.tags == ["a", "b", "c", "g", "h", "i"]
+
+
+def test_addition_combines_names_and_tags():
+    selector1 = ColumnSelector(["a", "b", "c"])
+    selector2 = ColumnSelector(tags=["g", "h", "i"])
+    combined = selector1 + selector2
+
+    assert combined.names == ["a", "b", "c"]
+    assert combined.tags == ["g", "h", "i"]
+
+
+def test_addition_enum_tags():
+    selector1 = ColumnSelector(tags=["a", "b", "c"])
+    combined = selector1 + Tags.CATEGORICAL
+
+    assert combined.tags == ["a", "b", "c", Tags.CATEGORICAL]
+
+    selector2 = ColumnSelector(["a", "b", "c", ["d", "e", "f"]])
+    combined = selector2 + Tags.CATEGORICAL
+
+    assert combined._names == ["a", "b", "c"]
+    assert combined.subgroups == [ColumnSelector(["d", "e", "f"])]
+    assert combined.tags == [Tags.CATEGORICAL]
diff --git a/tests/unit/workflow/test_workflow_schemas.py b/tests/unit/workflow/test_workflow_schemas.py
index 94d8275c316..e57deca79d0 100644
--- a/tests/unit/workflow/test_workflow_schemas.py
+++ b/tests/unit/workflow/test_workflow_schemas.py
@@ -18,7 +18,7 @@
 import pytest
 
 from nvtabular import Dataset, Workflow, ops
-from nvtabular.columns import ColumnSelector, Schema
+from nvtabular.columns import ColumnSchema, ColumnSelector, Schema
 
 
 def test_fit_schema():
@@ -130,6 +130,43 @@ def test_fit_schema_works_with_node_dependencies():
     assert workflow1.output_schema.column_names == ["TE_x_cost_renamed", "TE_y_cost_renamed"]
 
 
+# initial column selector works with tags
+# filter within the workflow by tags
+# test tags correct at output
+@pytest.mark.parametrize(
+    "op",
+    [
+        ops.Bucketize([1]),
+        ops.Rename(postfix="_trim"),
+        ops.Categorify(),
+        ops.Categorify(encode_type="combo"),
+        ops.Clip(0),
+        ops.DifferenceLag("col1"),
+        ops.FillMissing(),
+        ops.Groupby(["col1"]),
+        ops.HashBucket(1),
+        ops.HashedCross(1),
+        ops.JoinGroupby(["col1"]),
+        ops.ListSlice(0),
+        ops.LogOp(),
+        ops.Normalize(),
+        ops.TargetEncoding(["col1"]),
+    ],
+)
+def test_workflow_select_by_tags(op):
+    schema1 = ColumnSchema("col1", tags=["b", "c", "d"])
+    schema2 = ColumnSchema("col2", tags=["c", "d"])
+    schema3 = ColumnSchema("col3", tags=["d"])
+    schema = Schema([schema1, schema2, schema3])
+
+    cont_features = ColumnSelector(tags=["c"]) >> op
+    workflow = Workflow(cont_features)
+    workflow.fit_schema(schema)
+
+    output_cols = op.output_column_names(ColumnSelector(["col1", "col2"]))
+    assert len(workflow.output_schema.column_names) == len(output_cols.names)
+
+
 @pytest.mark.parametrize("engine", ["parquet"])
 def test_schema_write_read_dataset(tmpdir, dataset, engine):
     cat_names = ["name-cat", "name-string"] if engine == "parquet" else ["name-string"]