Skip to content

Commit

Permalink
select by tags (#1115)
Browse files Browse the repository at this point in the history
  • Loading branch information
jperez999 authored Sep 14, 2021
1 parent 9f141c6 commit 8e8a6f6
Show file tree
Hide file tree
Showing 10 changed files with 168 additions and 14 deletions.
17 changes: 10 additions & 7 deletions nvtabular/columns/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,16 +213,19 @@ def column_names(self):
return list(self.column_schemas.keys())

def apply(self, selector):
if selector and selector.names:
return self.select_by_name(selector.names)
else:
return self
if selector:
schema = Schema()
if selector.names:
schema += self.select_by_name(selector.names)
if selector.tags:
schema += self.select_by_tag(selector.tags)
return schema
return self

def apply_inverse(self, selector):
if selector:
return self - self.select_by_name(selector.names)
else:
return self
return self

def select_by_tag(self, tags):
if not isinstance(tags, list):
Expand All @@ -231,7 +234,7 @@ def select_by_tag(self, tags):
selected_schemas = {}

for _, column_schema in self.column_schemas.items():
if all(x in column_schema.tags for x in tags):
if any(x in column_schema.tags for x in tags):
selected_schemas[column_schema.name] = column_schema

return Schema(selected_schemas)
Expand Down
27 changes: 23 additions & 4 deletions nvtabular/columns/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import List
from typing import List, Union

import nvtabular
from nvtabular.tags import Tags


class ColumnSelector:
Expand All @@ -35,9 +36,16 @@ class ColumnSelector:
of nesting tuples inside the list of names)
"""

def __init__(self, names: List[str] = None, subgroups: List["ColumnSelector"] = None):
def __init__(
self,
names: List[str] = None,
subgroups: List["ColumnSelector"] = None,
tags: List[Union[Tags, str]] = None,
):
self._names = names if names is not None else []
self.subgroups = subgroups if subgroups else []
self._tags = tags if tags is not None else []
self.subgroups = subgroups if subgroups is not None else []

if isinstance(self._names, nvtabular.WorkflowNode):
raise TypeError("ColumnSelectors can not contain WorkflowNodes")

Expand All @@ -60,6 +68,10 @@ def __init__(self, names: List[str] = None, subgroups: List["ColumnSelector"] =
self._names = plain_names
self._nested_check()

@property
def tags(self):
    """Return the selector's tags with duplicates removed.

    The first occurrence of each tag keeps its position, so the returned
    list preserves the order in which the tags were supplied.
    """
    # dict.fromkeys de-duplicates while keeping insertion order; iterating a
    # dict already yields its keys, so no explicit .keys() call is needed.
    return list(dict.fromkeys(self._tags))

@property
def names(self):
names = []
Expand Down Expand Up @@ -92,7 +104,14 @@ def __add__(self, other):
elif isinstance(other, nvtabular.WorkflowNode):
return other + self
elif isinstance(other, ColumnSelector):
return ColumnSelector(self._names + other._names, self.subgroups + other.subgroups)

return ColumnSelector(
self._names + other._names,
self.subgroups + other.subgroups,
tags=self._tags + other._tags,
)
elif isinstance(other, Tags):
return ColumnSelector(self._names, self.subgroups, tags=self._tags + [other])
else:
if isinstance(other, str):
other = [other]
Expand Down
7 changes: 7 additions & 0 deletions nvtabular/ops/fill.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,13 @@ def inference_initialize(self, col_selector, inference_config):
def compute_output_schema(self, input_schema: Schema, col_selector: ColumnSelector) -> Schema:
if not col_selector:
col_selector = ColumnSelector(input_schema.column_names)
if col_selector.tags:
tags_col_selector = ColumnSelector(tags=col_selector.tags)
filtered_schema = input_schema.apply(tags_col_selector)
col_selector += ColumnSelector(filtered_schema.column_names)

# zero tags because already filtered
col_selector._tags = []
output_schema = Schema()
for column_name in col_selector.names:
column_schema = input_schema.column_schemas[column_name]
Expand Down
7 changes: 7 additions & 0 deletions nvtabular/ops/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,13 @@ def _dtypes(self):
return numpy.int64

def compute_output_schema(self, input_schema: Schema, col_selector: ColumnSelector) -> Schema:
if col_selector.tags:
tags_col_selector = ColumnSelector(tags=col_selector.tags)
filtered_schema = input_schema.apply(tags_col_selector)
col_selector += ColumnSelector(filtered_schema.column_names)

# zero tags because already filtered
col_selector._tags = []
new_col_selector = self.output_column_names(col_selector)
new_list = []
for name in col_selector.names:
Expand Down
9 changes: 9 additions & 0 deletions nvtabular/ops/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,15 @@ def compute_output_schema(self, input_schema: Schema, col_selector: ColumnSelect
"""
if not col_selector:
col_selector = ColumnSelector(input_schema.column_names)

if col_selector.tags:
tags_col_selector = ColumnSelector(tags=col_selector.tags)
filtered_schema = input_schema.apply(tags_col_selector)
col_selector += ColumnSelector(filtered_schema.column_names)

# zero tags because already filtered
col_selector._tags = []

col_selector = self.output_column_names(col_selector)

for column_name in col_selector.names:
Expand Down
7 changes: 7 additions & 0 deletions nvtabular/ops/rename.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@ def transform(self, col_selector: ColumnSelector, df: DataFrameType) -> DataFram
def compute_output_schema(self, input_schema: Schema, col_selector: ColumnSelector) -> Schema:
if not col_selector:
col_selector = ColumnSelector(input_schema.column_names)
if col_selector.tags:
tags_col_selector = ColumnSelector(tags=col_selector.tags)
filtered_schema = input_schema.apply(tags_col_selector)
col_selector += ColumnSelector(filtered_schema.column_names)

# zero tags because already filtered
col_selector._tags = []
output_schema = Schema()
for column_name in input_schema.column_schemas:
new_names = self.output_column_names(ColumnSelector(column_name))
Expand Down
2 changes: 1 addition & 1 deletion nvtabular/workflow/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def compute_schemas(self, root_schema):
len(self.parents) == 1
and isinstance(self.parents[0].op, internal.ConcatColumns)
and self.parents[0].selector
and self.parents[0].selector.names
and (self.parents[0].selector.names)
):

self.selector = self.parents[0].selector
Expand Down
24 changes: 23 additions & 1 deletion tests/unit/columns/test_column_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def test_dataset_schema_column_names():
assert ds_schema.column_names == ["x", "y", "z"]


def test_applying_selector_to_schema_selects_relevant_columns():
def test_applying_selector_to_schema_selects_by_name():
schema = Schema(["a", "b", "c", "d", "e"])
selector = ColumnSelector(["a", "b"])
result = schema.apply(selector)
Expand All @@ -173,6 +173,28 @@ def test_applying_selector_to_schema_selects_relevant_columns():
assert result == schema


def test_applying_selector_to_schema_selects_by_tags():
    """Applying a tags-only selector keeps columns carrying any requested tag.

    A column that shares no tag with the selector is included so the test
    fails if ``Schema.apply`` hands back the schema unfiltered.
    """
    schema1 = ColumnSchema("col1", tags=["a", "b", "c"])
    schema2 = ColumnSchema("col2", tags=["b", "c", "d"])
    # Has neither "a" nor "b", so it must not appear in the result.
    schema3 = ColumnSchema("col3", tags=["e"])

    schema = Schema([schema1, schema2, schema3])
    selector = ColumnSelector(tags=["a", "b"])
    result = schema.apply(selector)

    assert result.column_names == ["col1", "col2"]


def test_applying_selector_to_schema_selects_by_name_or_tags():
    """A selector with names and tags selects columns matching either one.

    ``col1`` is picked by name, ``col2`` by tag; ``col3`` matches neither and
    is present so the test fails if the schema is returned unfiltered.
    """
    schema1 = ColumnSchema("col1")
    schema2 = ColumnSchema("col2", tags=["b", "c", "d"])
    # Not named in the selector and shares no tag with it.
    schema3 = ColumnSchema("col3", tags=["e"])

    schema = Schema([schema1, schema2, schema3])
    selector = ColumnSelector(["col1"], tags=["a", "b"])
    result = schema.apply(selector)

    assert result.column_names == ["col1", "col2"]


def test_applying_inverse_selector_to_schema_selects_relevant_columns():
schema = Schema(["a", "b", "c", "d", "e"])
selector = ColumnSelector(["a", "b"])
Expand Down
43 changes: 43 additions & 0 deletions tests/unit/columns/test_column_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from nvtabular.columns import ColumnSelector
from nvtabular.ops import Operator
from nvtabular.tags import Tags
from nvtabular.workflow import WorkflowNode


Expand Down Expand Up @@ -147,3 +148,45 @@ def test_rshift_operator_onto_selector_creates_node_with_selector():
assert isinstance(output_node, WorkflowNode)
assert output_node.selector == selector
assert output_node.parents == []


def test_construct_column_selector_with_tags():
    """Tags passed to the constructor (enum members or plain strings) are exposed via ``tags``."""
    expected = [Tags.CATEGORICAL, "custom"]
    assert ColumnSelector(tags=expected).tags == expected


def test_returned_tags_are_unique():
    """Duplicate tags are collapsed, keeping the first occurrence of each."""
    deduplicated = ColumnSelector(tags=["a", "b", "a"]).tags
    assert deduplicated == ["a", "b"]


def test_addition_combines_tags():
    """Adding two tag selectors concatenates their tags, left operand first."""
    left = ColumnSelector(tags=["a", "b", "c"])
    right = ColumnSelector(tags=["g", "h", "i"])

    assert (left + right).tags == ["a", "b", "c", "g", "h", "i"]


def test_addition_combines_names_and_tags():
    """Adding a name selector and a tag selector keeps both the names and the tags."""
    by_name = ColumnSelector(["a", "b", "c"])
    by_tag = ColumnSelector(tags=["g", "h", "i"])
    merged = by_name + by_tag

    assert merged.names == ["a", "b", "c"]
    assert merged.tags == ["g", "h", "i"]


def test_addition_enum_tags():
    """Adding a ``Tags`` member appends it to the tags and leaves names and subgroups alone."""
    # Selector that already carries string tags.
    tagged = ColumnSelector(tags=["a", "b", "c"]) + Tags.CATEGORICAL
    assert tagged.tags == ["a", "b", "c", Tags.CATEGORICAL]

    # Selector with plain names plus a nested subgroup and no tags.
    nested = ColumnSelector(["a", "b", "c", ["d", "e", "f"]]) + Tags.CATEGORICAL
    assert nested._names == ["a", "b", "c"]
    assert nested.subgroups == [ColumnSelector(["d", "e", "f"])]
    assert nested.tags == [Tags.CATEGORICAL]
39 changes: 38 additions & 1 deletion tests/unit/workflow/test_workflow_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import pytest

from nvtabular import Dataset, Workflow, ops
from nvtabular.columns import ColumnSelector, Schema
from nvtabular.columns import ColumnSchema, ColumnSelector, Schema


def test_fit_schema():
Expand Down Expand Up @@ -130,6 +130,43 @@ def test_fit_schema_works_with_node_dependencies():
assert workflow1.output_schema.column_names == ["TE_x_cost_renamed", "TE_y_cost_renamed"]


# initial column selector works with tags
# filter within the workflow by tags
# test tags correct at output
@pytest.mark.parametrize(
    "op",
    [
        ops.Bucketize([1]),
        ops.Rename(postfix="_trim"),
        ops.Categorify(),
        ops.Categorify(encode_type="combo"),
        ops.Clip(0),
        ops.DifferenceLag("col1"),
        ops.FillMissing(),
        ops.Groupby(["col1"]),
        ops.HashBucket(1),
        ops.HashedCross(1),
        ops.JoinGroupby(["col1"]),
        ops.ListSlice(0),
        ops.LogOp(),
        ops.Normalize(),
        ops.TargetEncoding(["col1"]),
    ],
)
def test_workflow_select_by_tags(op):
    """A workflow fed by a tag selector emits as many columns as the same op on the tagged columns."""
    input_schema = Schema(
        [
            ColumnSchema("col1", tags=["b", "c", "d"]),
            ColumnSchema("col2", tags=["c", "d"]),
            ColumnSchema("col3", tags=["d"]),
        ]
    )

    workflow = Workflow(ColumnSelector(tags=["c"]) >> op)
    workflow.fit_schema(input_schema)

    # Only col1 and col2 carry tag "c", so the output should have exactly as
    # many columns as the op produces for an explicit ["col1", "col2"] selection.
    expected = op.output_column_names(ColumnSelector(["col1", "col2"]))
    assert len(workflow.output_schema.column_names) == len(expected.names)


@pytest.mark.parametrize("engine", ["parquet"])
def test_schema_write_read_dataset(tmpdir, dataset, engine):
cat_names = ["name-cat", "name-string"] if engine == "parquet" else ["name-string"]
Expand Down

0 comments on commit 8e8a6f6

Please sign in to comment.