feat(classification): support for regex based custom infotypes #8177

Merged · 4 commits · Jun 6, 2023
2 changes: 1 addition & 1 deletion metadata-ingestion/setup.py
@@ -194,7 +194,7 @@ def get_long_description():
     "pandas",
     "cryptography",
     "msal",
-    "acryl-datahub-classify==0.0.6",
+    "acryl-datahub-classify==0.0.7",
     # spacy version restricted to reduce backtracking, used by acryl-datahub-classify,
     "spacy==3.4.3",
 }
metadata-ingestion/src/datahub/ingestion/glossary/datahub_classifier.py
@@ -1,9 +1,10 @@
+from enum import Enum
 from typing import Any, Dict, List, Optional

 from datahub_classify.helper_classes import ColumnInfo
 from datahub_classify.infotype_predictor import predict_infotypes
 from datahub_classify.reference_input import input1 as default_config
-from pydantic.class_validators import root_validator
+from pydantic import validator
 from pydantic.fields import Field

 from datahub.configuration.common import ConfigModel
@@ -31,12 +32,20 @@ class DataTypeFactorConfig(ConfigModel):
     )


+class ValuePredictionType(str, Enum):
+    REGEX = "regex"
+    LIBRARY = "library"
+
+
 class ValuesFactorConfig(ConfigModel):
-    prediction_type: str
+    prediction_type: ValuePredictionType
     regex: Optional[List[str]] = Field(
         default=None,
         description="List of regex patterns the column value follows for the info type",
     )
-    library: Optional[List[str]] = Field(description="Library used for prediction")
+    library: Optional[List[str]] = Field(
+        default=None, description="Library used for prediction"
+    )


 class PredictionFactorsAndWeights(ConfigModel):
@@ -68,6 +77,11 @@ class Config:
     Values: Optional[ValuesFactorConfig] = Field(default=None, alias="values")


+DEFAULT_CLASSIFIER_CONFIG = {
+    k: InfoTypeConfig.parse_obj(v) for k, v in default_config.items()
+}
+
+
 # TODO: Generate Classification doc (classification.md) from python source.
 class DataHubClassifierConfig(ConfigModel):
     confidence_level_threshold: float = Field(
@@ -81,29 +95,58 @@ class DataHubClassifierConfig(ConfigModel):
         description=f"List of infotypes to be predicted. By default, all supported infotypes are considered. If specified, this should be a subset of {list(default_config.keys())}.",
     )
     info_types_config: Dict[str, InfoTypeConfig] = Field(
-        default={k: InfoTypeConfig.parse_obj(v) for k, v in default_config.items()},
+        default=DEFAULT_CLASSIFIER_CONFIG,
         init=False,
         description="Configuration details for infotypes. See [reference_input.py](https://github.com/acryldata/datahub-classify/blob/main/datahub-classify/src/datahub_classify/reference_input.py) for default configuration.",
     )

-    @root_validator
-    def provided_config_selectively_overrides_default_config(cls, values):
-        override: Dict[str, InfoTypeConfig] = values.get("info_types_config")
-        base = {k: InfoTypeConfig.parse_obj(v) for k, v in default_config.items()}
-        for k, v in base.items():
-            if k not in override.keys():
-                # use default InfoTypeConfig for info type key if not specified in recipe
-                values["info_types_config"][k] = v
+    @validator("info_types_config")
+    def input_config_selectively_overrides_default_config(cls, info_types_config):
+        for infotype, infotype_config in DEFAULT_CLASSIFIER_CONFIG.items():
+            if infotype not in info_types_config:
+                # if config for some info type is not provided by user, use default config for that info type.
+                info_types_config[infotype] = infotype_config
             else:
-                for factor, _ in (
-                    override[k].Prediction_Factors_and_Weights.dict().items()
-                ):
-                    # use default FactorConfig for factor if not specified in recipe
-                    if getattr(override[k], factor) is None:
+                # if config for info type is provided by user but config for its prediction factor is missing,
+                # use default config for that prediction factor.
+                for factor, weight in (
+                    info_types_config[infotype]
+                    .Prediction_Factors_and_Weights.dict()
+                    .items()
+                ):
+                    if (
+                        weight > 0
+                        and getattr(info_types_config[infotype], factor) is None
+                    ):
                         setattr(
-                            values["info_types_config"][k], factor, getattr(v, factor)
+                            info_types_config[infotype],
+                            factor,
+                            getattr(infotype_config, factor),
                         )
-        return values
+
+        # Custom info type
+        custom_infotypes = info_types_config.keys() - DEFAULT_CLASSIFIER_CONFIG.keys()
+
+        for custom_infotype in custom_infotypes:
+            custom_infotype_config = info_types_config[custom_infotype]
+            # for custom infotype, config for every prediction factor must be specified.
+            for (
+                factor,
+                weight,
+            ) in custom_infotype_config.Prediction_Factors_and_Weights.dict().items():
+                if weight > 0:
+                    assert (
+                        getattr(custom_infotype_config, factor) is not None
+                    ), f"Missing Configuration for Prediction Factor {factor} for Custom Info Type {custom_infotype}"
+
+            # Custom infotype supports only regex based prediction for column values
+            if custom_infotype_config.Prediction_Factors_and_Weights.Values > 0:
+                assert custom_infotype_config.Values
+                assert (
+                    custom_infotype_config.Values.prediction_type
+                    == ValuePredictionType.REGEX
+                ), f"Invalid Prediction Type for Values for Custom Info Type {custom_infotype}. Only `regex` is supported."
+
+        return info_types_config


 class DataHubClassifier(Classifier):
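Taken together, the new validator means built-in info types only need the factors a recipe overrides (the rest are backfilled from DEFAULT_CLASSIFIER_CONFIG), while a custom info type must configure every factor it weights, with regex as the only supported Values predictor. A minimal sketch of declaring a custom info type against these classes — the "EmployeeId" name and EMP-\d{6} pattern are illustrative, not part of this PR:

# Hypothetical custom info type; mirrors how the integration test below
# configures "CloudRegion".
from datahub.ingestion.glossary.datahub_classifier import (
    DataHubClassifierConfig,
    InfoTypeConfig,
    PredictionFactorsAndWeights,
    ValuesFactorConfig,
)

config = DataHubClassifierConfig(
    info_types_config={
        # Custom info type: every factor with weight > 0 must be configured,
        # and Values may only use regex-based prediction.
        "EmployeeId": InfoTypeConfig(
            Prediction_Factors_and_Weights=PredictionFactorsAndWeights(
                Name=0, Description=0, Datatype=0, Values=1
            ),
            Values=ValuesFactorConfig(
                prediction_type="regex", regex=[r"EMP-\d{6}"]
            ),
        ),
    }
)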
110 changes: 110 additions & 0 deletions metadata-ingestion/tests/integration/snowflake/snowflake_golden.json
@@ -318,6 +318,17 @@
       },
       "nativeDataType": "VARCHAR(255)",
       "recursive": false,
+      "glossaryTerms": {
+        "terms": [
+          {
+            "urn": "urn:li:glossaryTerm:CloudRegion"
+          }
+        ],
+        "auditStamp": {
+          "time": 1654621200000,
+          "actor": "urn:li:corpuser:datahub"
+        }
+      },
       "isPartOfKey": false
     },
     {
@@ -576,6 +587,17 @@
       },
       "nativeDataType": "VARCHAR(255)",
       "recursive": false,
+      "glossaryTerms": {
+        "terms": [
+          {
+            "urn": "urn:li:glossaryTerm:CloudRegion"
+          }
+        ],
+        "auditStamp": {
+          "time": 1654621200000,
+          "actor": "urn:li:corpuser:datahub"
+        }
+      },
       "isPartOfKey": false
     },
     {
@@ -834,6 +856,17 @@
       },
       "nativeDataType": "VARCHAR(255)",
       "recursive": false,
+      "glossaryTerms": {
+        "terms": [
+          {
+            "urn": "urn:li:glossaryTerm:CloudRegion"
+          }
+        ],
+        "auditStamp": {
+          "time": 1654621200000,
+          "actor": "urn:li:corpuser:datahub"
+        }
+      },
       "isPartOfKey": false
     },
     {
@@ -1092,6 +1125,17 @@
       },
       "nativeDataType": "VARCHAR(255)",
       "recursive": false,
+      "glossaryTerms": {
+        "terms": [
+          {
+            "urn": "urn:li:glossaryTerm:CloudRegion"
+          }
+        ],
+        "auditStamp": {
+          "time": 1654621200000,
+          "actor": "urn:li:corpuser:datahub"
+        }
+      },
       "isPartOfKey": false
     },
     {
@@ -1350,6 +1394,17 @@
       },
       "nativeDataType": "VARCHAR(255)",
       "recursive": false,
+      "glossaryTerms": {
+        "terms": [
+          {
+            "urn": "urn:li:glossaryTerm:CloudRegion"
+          }
+        ],
+        "auditStamp": {
+          "time": 1654621200000,
+          "actor": "urn:li:corpuser:datahub"
+        }
+      },
       "isPartOfKey": false
     },
     {
@@ -1608,6 +1663,17 @@
       },
       "nativeDataType": "VARCHAR(255)",
       "recursive": false,
+      "glossaryTerms": {
+        "terms": [
+          {
+            "urn": "urn:li:glossaryTerm:CloudRegion"
+          }
+        ],
+        "auditStamp": {
+          "time": 1654621200000,
+          "actor": "urn:li:corpuser:datahub"
+        }
+      },
       "isPartOfKey": false
     },
     {
@@ -1866,6 +1932,17 @@
       },
       "nativeDataType": "VARCHAR(255)",
       "recursive": false,
+      "glossaryTerms": {
+        "terms": [
+          {
+            "urn": "urn:li:glossaryTerm:CloudRegion"
+          }
+        ],
+        "auditStamp": {
+          "time": 1654621200000,
+          "actor": "urn:li:corpuser:datahub"
+        }
+      },
       "isPartOfKey": false
     },
     {
@@ -2124,6 +2201,17 @@
       },
       "nativeDataType": "VARCHAR(255)",
       "recursive": false,
+      "glossaryTerms": {
+        "terms": [
+          {
+            "urn": "urn:li:glossaryTerm:CloudRegion"
+          }
+        ],
+        "auditStamp": {
+          "time": 1654621200000,
+          "actor": "urn:li:corpuser:datahub"
+        }
+      },
       "isPartOfKey": false
     },
     {
@@ -2382,6 +2470,17 @@
       },
       "nativeDataType": "VARCHAR(255)",
       "recursive": false,
+      "glossaryTerms": {
+        "terms": [
+          {
+            "urn": "urn:li:glossaryTerm:CloudRegion"
+          }
+        ],
+        "auditStamp": {
+          "time": 1654621200000,
+          "actor": "urn:li:corpuser:datahub"
+        }
+      },
       "isPartOfKey": false
     },
     {
@@ -2640,6 +2739,17 @@
       },
       "nativeDataType": "VARCHAR(255)",
       "recursive": false,
+      "glossaryTerms": {
+        "terms": [
+          {
+            "urn": "urn:li:glossaryTerm:CloudRegion"
+          }
+        ],
+        "auditStamp": {
+          "time": 1654621200000,
+          "actor": "urn:li:corpuser:datahub"
+        }
+      },
       "isPartOfKey": false
     },
     {
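Each hunk above attaches the same glossaryTerms payload to a different column in the golden snapshot. As a rough sketch (not part of this PR) of how that aspect maps onto DataHub's generated Python classes, assuming the acryl-datahub package's schema_classes module, whose field names mirror the JSON keys:

# Builds the glossaryTerms aspect shown in the golden file above.
from datahub.metadata.schema_classes import (
    AuditStampClass,
    GlossaryTermAssociationClass,
    GlossaryTermsClass,
)

glossary_terms = GlossaryTermsClass(
    terms=[GlossaryTermAssociationClass(urn="urn:li:glossaryTerm:CloudRegion")],
    auditStamp=AuditStampClass(time=1654621200000, actor="urn:li:corpuser:datahub"),
)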
34 changes: 31 additions & 3 deletions metadata-ingestion/tests/integration/snowflake/test_snowflake.py
@@ -17,6 +17,7 @@
     DataHubClassifierConfig,
     InfoTypeConfig,
     PredictionFactorsAndWeights,
+    ValuesFactorConfig,
 )
 from datahub.ingestion.run.pipeline import Pipeline
 from datahub.ingestion.run.pipeline_config import PipelineConfig, SourceConfig
@@ -42,6 +43,18 @@ def random_email():
     )


+def random_cloud_region():
+    return "".join(
+        [
+            random.choice(["af", "ap", "ca", "eu", "me", "sa", "us"]),
+            "-",
+            random.choice(["central", "north", "south", "east", "west"]),
+            "-",
+            str(random.randint(1, 2)),
+        ]
+    )
+
+
 @freeze_time(FROZEN_TIME)
 @pytest.mark.integration
 def test_snowflake_basic(pytestconfig, tmp_path, mock_time, mock_datahub_graph):
@@ -63,8 +76,9 @@ def test_snowflake_basic(pytestconfig, tmp_path, mock_time, mock_datahub_graph):

         mock_sample_values.return_value = pd.DataFrame(
             data={
-                "col_1": [random.randint(0, 100) for i in range(1, 200)],
-                "col_2": [random_email() for i in range(1, 200)],
+                "col_1": [random.randint(0, 100) for i in range(1, 100)],
+                "col_2": [random_email() for i in range(1, 100)],
+                "col_3": [random_cloud_region() for i in range(1, 100)],
             }
         )

@@ -76,6 +90,20 @@ def test_snowflake_basic(pytestconfig, tmp_path, mock_time, mock_datahub_graph):
                     Name=0, Values=1, Description=0, Datatype=0
                 )
             ),
+            "CloudRegion": InfoTypeConfig(
+                Prediction_Factors_and_Weights=PredictionFactorsAndWeights(
+                    Name=0,
+                    Description=0,
+                    Datatype=0,
+                    Values=1,
+                ),
+                Values=ValuesFactorConfig(
+                    prediction_type="regex",
+                    regex=[
+                        r"(af|ap|ca|eu|me|sa|us)-(central|north|(north(?:east|west))|south|south(?:east|west)|east|west)-\d+"
+                    ],
+                ),
+            ),
         }
         pipeline = Pipeline(
             config=PipelineConfig(
@@ -103,7 +131,7 @@ def test_snowflake_basic(pytestconfig, tmp_path, mock_time, mock_datahub_graph):
                         classification=ClassificationConfig(
                             enabled=True,
                             column_pattern=AllowDenyPattern(
-                                allow=[".*col_1$", ".*col_2$"]
+                                allow=[".*col_1$", ".*col_2$", ".*col_3$"]
                             ),
                             classifiers=[
                                 DynamicTypedClassifierConfig(
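As a quick standalone sanity check (not part of the PR) that the CloudRegion regex in the classifier config accepts the kind of values random_cloud_region() generates, using only the standard library:

# Sample values follow random_cloud_region()'s "prefix-direction-digit" shape.
import re

CLOUD_REGION_REGEX = (
    r"(af|ap|ca|eu|me|sa|us)-"
    r"(central|north|(north(?:east|west))|south|south(?:east|west)|east|west)-\d+"
)

for value in ["us-east-1", "eu-central-2", "ap-south-1"]:
    assert re.fullmatch(CLOUD_REGION_REGEX, value), value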