#60 unified input data validation #87

Merged 2 commits on Dec 2, 2024
@@ -15,13 +15,14 @@
from pyspark.sql import DataFrame as PySparkDataFrame

from ..interfaces import DataManipulationBaseInterface
+from ...input_validator import InputValidator
from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import (
Libraries,
SystemType,
)


-class DuplicateDetection(DataManipulationBaseInterface):
+class DuplicateDetection(DataManipulationBaseInterface, InputValidator):
"""
    Cleanses a PySpark DataFrame of duplicates.

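For orientation, a minimal PySpark sketch of the kind of duplicate removal this component performs (illustrative only, not the class's internal implementation; the key columns are assumptions):

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").appName("duplicate-detection-sketch").getOrCreate()

df = spark.createDataFrame(
    [
        ("A2PS64V0J.:ZUX09R", "2024-01-01 03:29:21.000", "Good", "1.0"),
        ("A2PS64V0J.:ZUX09R", "2024-01-01 03:29:21.000", "Good", "1.0"),  # exact duplicate
        ("A2PS64V0J.:ZUX09R", "2024-01-01 07:32:55.000", "Good", "2.0"),
    ],
    ["TagName", "EventTime", "Status", "Value"],
)

# Keep one row per (TagName, EventTime) combination and drop the rest.
deduplicated = df.dropDuplicates(["TagName", "EventTime"])
deduplicated.show()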
@@ -24,9 +24,10 @@
SystemType,
)
from ..interfaces import DataManipulationBaseInterface
+from ...input_validator import InputValidator


-class IntervalFiltering(DataManipulationBaseInterface):
+class IntervalFiltering(DataManipulationBaseInterface, InputValidator):
"""
    Cleanses a DataFrame by removing rows outside a specified interval window. Supported timestamp columns are DateType and StringType.

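As a rough sketch of the interval-window filtering described above (illustrative only, not the class's implementation; the window bounds and column names are assumptions):

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local[2]").appName("interval-filtering-sketch").getOrCreate()

df = spark.createDataFrame(
    [
        ("A2PS64V0J.:ZUX09R", "2024-01-01 03:29:21", "1.0"),
        ("A2PS64V0J.:ZUX09R", "2024-01-02 07:32:55", "2.0"),
    ],
    ["TagName", "EventTime", "Value"],
)

# Parse the string timestamps and keep only rows inside the interval window.
start, end = "2024-01-01 00:00:00", "2024-01-01 23:59:59"
filtered = df.withColumn("EventTime", F.to_timestamp("EventTime")).filter(
    F.col("EventTime").between(start, end)
)
filtered.show()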
@@ -15,13 +15,14 @@
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.functions import mean, stddev, abs, col
from ..interfaces import DataManipulationBaseInterface
+from ...input_validator import InputValidator
from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import (
Libraries,
SystemType,
)


-class KSigmaAnomalyDetection(DataManipulationBaseInterface):
+class KSigmaAnomalyDetection(DataManipulationBaseInterface, InputValidator):
"""
    Anomaly detection with the k-sigma method. This method either computes the mean and standard deviation, or the median and the median absolute deviation (MAD), of the data.
    The k-sigma method then filters out all data points that are more than k times the standard deviation away from the mean, or more than k times the MAD away from the median.
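A minimal sketch of the mean/standard-deviation variant of the k-sigma rule in plain PySpark (illustrative only; the class also supports the median/MAD variant, and the value of k and the data are assumptions):

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local[2]").appName("k-sigma-sketch").getOrCreate()

df = spark.createDataFrame(
    [(1.0,), (2.0,), (2.5,), (1.5,), (2.2,), (1.8,), (2.1,), (100.0,)], ["Value"]
)
k = 2.0

# Keep only points within k standard deviations of the mean; the 100.0 outlier is removed.
stats = df.select(F.mean("Value").alias("mean"), F.stddev("Value").alias("std")).first()
filtered = df.filter(F.abs(F.col("Value") - stats["mean"]) <= k * stats["std"])
filtered.show()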
@@ -21,13 +21,14 @@
from datetime import timedelta
from typing import List
from ..interfaces import DataManipulationBaseInterface
+from ...input_validator import InputValidator
from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import (
Libraries,
SystemType,
)


-class MissingValueImputation(DataManipulationBaseInterface):
+class MissingValueImputation(DataManipulationBaseInterface, InputValidator):
"""
    Imputes missing values in a univariate time series, creating a continuous curve of data points. For that, the
    time intervals of each individual source are calculated, to then insert empty records at the missing timestamps with
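A small pandas sketch of the general idea — detect the source's sampling interval, insert the missing timestamps, then interpolate (illustrative only, not the component's Spark implementation or its interpolation method):

import pandas as pd

s = pd.Series(
    [1.0, 2.0, 3.0, 6.0],
    index=pd.to_datetime(
        ["2024-01-01 00:00", "2024-01-01 00:10", "2024-01-01 00:20", "2024-01-01 00:50"]
    ),
)

# Infer the typical sampling interval, reindex to a complete time grid, then interpolate.
interval = s.index.to_series().diff().median()
full_index = pd.date_range(s.index.min(), s.index.max(), freq=interval)
imputed = s.reindex(full_index).interpolate(method="time")
print(imputed)  # the records at 00:30 and 00:40 are filled with 4.0 and 5.0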
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pyspark.sql import DataFrame as PySparkDataFrame
+from ....input_validator import InputValidator
from src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.interfaces import (
DataManipulationBaseInterface,
)
@@ -24,7 +25,7 @@
)


-class Denormalization(DataManipulationBaseInterface):
+class Denormalization(DataManipulationBaseInterface, InputValidator):
"""
#TODO
Applies the appropriate denormalization method to revert values to their original scale.
@@ -14,6 +14,7 @@
from abc import abstractmethod
from pyspark.sql import DataFrame as PySparkDataFrame
from typing import List
+from ....input_validator import InputValidator
from src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.interfaces import (
DataManipulationBaseInterface,
)
@@ -23,7 +24,7 @@
)


-class NormalizationBaseClass(DataManipulationBaseInterface):
+class NormalizationBaseClass(DataManipulationBaseInterface, InputValidator):
"""
A base class for applying normalization techniques to multiple columns in a PySpark DataFrame.
This class serves as a framework to support various normalization methods (e.g., Z-Score, Min-Max, and Mean),
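For reference, a minimal Z-Score normalization of a single column in PySpark, the kind of technique a concrete subclass would implement (a sketch under assumed column names, not the subclass code itself):

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local[2]").appName("normalization-sketch").getOrCreate()

df = spark.createDataFrame([(1.0,), (2.0,), (3.0,), (4.0,)], ["Value"])

# Z-Score: (x - mean) / stddev, computed over the whole column.
stats = df.select(F.mean("Value").alias("mean"), F.stddev("Value").alias("std")).first()
normalized = df.withColumn("Value_zscore", (F.col("Value") - stats["mean"]) / stats["std"])
normalized.show()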
@@ -22,13 +22,14 @@
from pmdarima import auto_arima

from ...interfaces import DataManipulationBaseInterface
+from ....input_validator import InputValidator
from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import (
Libraries,
SystemType,
)


-class ArimaPrediction(DataManipulationBaseInterface):
+class ArimaPrediction(DataManipulationBaseInterface, InputValidator):
"""
    Extends the time series data in the given DataFrame with forecasted values from an ARIMA model. Can optionally be set
    to use auto_arima, which operates a bit like a grid search, in that it tries various sets of p and q (also P and Q
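Since the class builds on pmdarima's auto_arima, a brief standalone sketch of that call may help; the toy series and parameters are assumptions for illustration, not the component's defaults:

import numpy as np
from pmdarima import auto_arima

# A toy univariate series; auto_arima searches over (p, d, q) orders, similar to a grid search.
history = np.array([1.0, 1.2, 1.1, 1.3, 1.4, 1.35, 1.5, 1.6, 1.55, 1.7])

model = auto_arima(history, seasonal=False, suppress_warnings=True)
forecast = model.predict(n_periods=5)  # extend the series with 5 forecasted values
print(forecast)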
157 changes: 157 additions & 0 deletions src/sdk/python/rtdip_sdk/pipelines/data_quality/input_validator.py
@@ -0,0 +1,157 @@
# Copyright 2022 RTDIP
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pyspark.sql.types import DataType, StructType
from pyspark.sql import functions as F
from ..interfaces import PipelineComponentBaseInterface
from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import (
Libraries,
SystemType,
)


class InputValidator(PipelineComponentBaseInterface):
"""
    Validates the PySpark DataFrame of the respective child class instance against a schema dictionary or a PySpark
    StructType. Checks for column availability and column data types. If data types differ, it tries to cast the
    column into the expected data type. Raises an error if any step fails.

    Example:
    --------
    ```python
    from pyspark.sql import SparkSession
    from pyspark.sql.types import StructType, StructField, StringType, TimestampType, FloatType
    from src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.spark.missing_value_imputation import (
        MissingValueImputation,
    )

    spark = SparkSession.builder.master("local[2]").appName("test").getOrCreate()

    test_schema = StructType(
        [
            StructField("TagName", StringType(), True),
            StructField("EventTime", StringType(), True),
            StructField("Status", StringType(), True),
            StructField("Value", StringType(), True),
        ]
    )
    expected_schema = StructType(
        [
            StructField("TagName", StringType(), True),
            StructField("EventTime", TimestampType(), True),
            StructField("Status", StringType(), True),
            StructField("Value", FloatType(), True),
        ]
    )

    test_data = [
        ("A2PS64V0J.:ZUX09R", "2024-01-01 03:29:21.000", "Good", "1.0"),
        ("A2PS64V0J.:ZUX09R", "2024-01-01 07:32:55.000", "Good", "2.0"),
        ("A2PS64V0J.:ZUX09R", "2024-01-01 11:36:29.000", "Good", "3.0"),
    ]

    test_df = spark.createDataFrame(test_data, schema=test_schema)
    test_component = MissingValueImputation(spark, test_df)

    print(test_component.validate(expected_schema))  # True
    ```

    Parameters:
        schema_dict (dict or pyspark.sql.types.StructType): A dictionary mapping column names to expected
            PySpark data types, for example {"column1": StringType(), "column2": IntegerType()}, or an
            equivalent StructType.

    Returns:
        True if the data is valid; otherwise an error is raised.

    Raises:
        ValueError: If a column is missing or cannot be cast to the expected PySpark data type.
        TypeError: If an expected or actual column type is not a PySpark DataType instance.
"""

@staticmethod
def system_type():
"""
Attributes:
SystemType (Environment): Requires PYSPARK
"""
return SystemType.PYSPARK

@staticmethod
def libraries():
libraries = Libraries()
return libraries

@staticmethod
def settings() -> dict:
return {}

def validate(self, schema_dict):
"""
Used by child data quality utility classes to validate the input data.
"""
dataframe = getattr(self, "df", None)

if isinstance(schema_dict, StructType):
schema_dict = {field.name: field.dataType for field in schema_dict.fields}

dataframe_schema = {
field.name: field.dataType for field in dataframe.schema.fields
}

for column, expected_type in schema_dict.items():
# Check if the column exists
if column not in dataframe_schema:
raise ValueError(f"Column '{column}' is missing in the DataFrame.")

# Check if both types are of a pyspark data type
actual_type = dataframe_schema[column]
if not isinstance(actual_type, DataType) or not isinstance(
expected_type, DataType
):
raise TypeError(
"Expected and actual types must be instances of pyspark.sql.types.DataType."
)

            # If the actual type is not the expected type, try to cast the column.
            # A failed cast in Spark yields nulls, so the null counts before and after
            # the cast are compared to detect values that could not be converted.
            if not isinstance(actual_type, type(expected_type)):
                try:
                    original_null_count = dataframe.filter(
                        F.col(column).isNull()
                    ).count()
                    casted_column = dataframe.withColumn(
                        column, F.col(column).cast(expected_type)
                    )
                    new_null_count = casted_column.filter(
                        F.col(column).isNull()
                    ).count()
                except Exception as e:
                    raise ValueError(
                        f"Error while casting column '{column}' to {expected_type}: {str(e)}"
                    ) from e

                if new_null_count > original_null_count:
                    raise ValueError(
                        f"Column '{column}' cannot be cast to {expected_type}."
                    )
                dataframe = casted_column

self.df = dataframe
return True
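To show how the mixin is meant to be picked up by the data quality components, a minimal sketch (not part of the PR): a component only has to store its DataFrame on self.df for validate() to work, because validate() reads the frame via getattr(self, "df", None). MyComponent below is hypothetical and used purely for illustration:

from pyspark.sql import SparkSession
from pyspark.sql.types import FloatType, StringType
from src.sdk.python.rtdip_sdk.pipelines.data_quality.input_validator import InputValidator


class MyComponent(InputValidator):  # hypothetical component for illustration
    def __init__(self, df):
        self.df = df


spark = SparkSession.builder.master("local[2]").appName("validator-sketch").getOrCreate()
df = spark.createDataFrame([("A2PS64V0J.:ZUX09R", "1.0")], ["TagName", "Value"])

component = MyComponent(df)
# A plain dict works as well as a StructType; "Value" is cast from string to float.
print(component.validate({"TagName": StringType(), "Value": FloatType()}))  # True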
@@ -1,3 +1,17 @@
# Copyright 2022 RTDIP
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from pyspark.sql import DataFrame as PySparkDataFrame
from pyspark.sql.functions import col
@@ -11,9 +25,10 @@
Libraries,
SystemType,
)
+from ...input_validator import InputValidator


-class CheckValueRanges(MonitoringBaseInterface):
+class CheckValueRanges(MonitoringBaseInterface, InputValidator):
"""
Monitors data in a DataFrame by checking specified columns against expected value ranges.
Logs events when values exceed the specified ranges.
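A compact sketch of range checking with logging in plain PySpark (illustrative only; the component's configuration options are not shown and the range values are assumptions):

import logging

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("check_value_ranges_sketch")

spark = SparkSession.builder.master("local[2]").appName("value-range-sketch").getOrCreate()
df = spark.createDataFrame(
    [("A2PS64V0J.:ZUX09R", 1.0), ("A2PS64V0J.:ZUX09R", 250.0)], ["TagName", "Value"]
)

# Flag rows whose Value falls outside the expected [0, 100] range and log them.
min_value, max_value = 0.0, 100.0
out_of_range = df.filter((F.col("Value") < min_value) | (F.col("Value") > max_value))
for row in out_of_range.collect():
    logger.warning(f"Value outside range [{min_value}, {max_value}]: {row.asDict()}")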
@@ -24,9 +24,10 @@
Libraries,
SystemType,
)
+from ...input_validator import InputValidator


-class FlatlineDetection(MonitoringBaseInterface):
+class FlatlineDetection(MonitoringBaseInterface, InputValidator):
"""
Detects flatlining in specified columns of a PySpark DataFrame and logs warnings.

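As an illustration of flatline detection (consecutive repeated values), a window-function sketch; the column names and the single-step comparison are assumptions, not the component's parameters:

from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local[2]").appName("flatline-sketch").getOrCreate()
df = spark.createDataFrame(
    [("A", 1, 5.0), ("A", 2, 5.0), ("A", 3, 5.0), ("A", 4, 6.0)],
    ["TagName", "EventTime", "Value"],
)

# A value equal to its predecessor within the same tag is a candidate flatline point.
w = Window.partitionBy("TagName").orderBy("EventTime")
flagged = df.withColumn("is_flat", F.col("Value") == F.lag("Value").over(w))
flagged.show()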
@@ -27,10 +27,11 @@
from great_expectations.expectations.expectation import (
ExpectationConfiguration,
)
+from ...input_validator import InputValidator


# Create a new context
-class GreatExpectationsDataQuality(MonitoringBaseInterface):
+class GreatExpectationsDataQuality(MonitoringBaseInterface, InputValidator):
"""
    Data quality monitoring using Great Expectations, allowing you to create and check your data quality expectations.

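The imported ExpectationConfiguration is the building block for such checks. A small example of constructing one follows; the expectation type and kwargs are illustrative choices, and the surrounding Great Expectations context and suite setup is omitted:

from great_expectations.expectations.expectation import ExpectationConfiguration

# Expect every Value to lie between 0 and 100; this configuration would be added to an
# expectation suite and validated against the DataFrame by the monitoring component.
expectation = ExpectationConfiguration(
    expectation_type="expect_column_values_to_be_between",
    kwargs={"column": "Value", "min_value": 0, "max_value": 100},
)
print(expectation)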
@@ -26,11 +26,11 @@
from src.sdk.python.rtdip_sdk.pipelines.utilities.spark.time_string_parsing import (
parse_time_string_to_ms,
)

+from ...input_validator import InputValidator
from src.sdk.python.rtdip_sdk.pipelines.logging.logger_manager import LoggerManager


-class IdentifyMissingDataInterval(MonitoringBaseInterface):
+class IdentifyMissingDataInterval(MonitoringBaseInterface, InputValidator):
"""
Detects missing data intervals in a DataFrame by identifying time differences between consecutive
measurements that exceed a specified tolerance or a multiple of the Median Absolute Deviation (MAD).
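The MAD-based gap rule can be sketched in a few lines of pandas/NumPy (illustrative only; tolerance handling and logging in the component are richer than this, and k is an assumption):

import numpy as np
import pandas as pd

timestamps = pd.to_datetime(
    ["2024-01-01 00:00", "2024-01-01 00:10", "2024-01-01 00:21", "2024-01-01 00:30", "2024-01-01 01:30"]
)

# Time differences between consecutive measurements, in minutes.
diffs = np.diff(timestamps).astype("timedelta64[s]").astype(float) / 60.0

# Median Absolute Deviation of the gaps; gaps well above the median are flagged as missing intervals.
median = np.median(diffs)
mad = np.median(np.abs(diffs - median))
k = 3.0
missing_intervals = diffs[diffs > median + k * mad]
print(missing_intervals)  # [60.]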
@@ -1,3 +1,17 @@
# Copyright 2022 RTDIP
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

import pandas as pd
@@ -6,7 +20,7 @@


from ....logging.logger_manager import LoggerManager

+from ...input_validator import InputValidator
from src.sdk.python.rtdip_sdk.pipelines.data_quality.monitoring.interfaces import (
MonitoringBaseInterface,
)
@@ -19,7 +33,7 @@
)


-class IdentifyMissingDataPattern(MonitoringBaseInterface):
+class IdentifyMissingDataPattern(MonitoringBaseInterface, InputValidator):
"""
Identifies missing data in a DataFrame based on specified time patterns.
Logs the expected missing times.
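A small pandas sketch of checking actual timestamps against an expected time pattern — here, one reading expected every 30 minutes over the observed window; the pattern itself is an assumption for illustration:

import pandas as pd

actual = pd.to_datetime(["2024-01-01 00:00", "2024-01-01 01:00", "2024-01-01 01:30"])

# Build the expected timestamps from the pattern and report the ones that never arrived.
expected = pd.date_range(actual.min(), actual.max(), freq="30min")
missing = expected.difference(actual)
print(list(missing))  # [Timestamp('2024-01-01 00:30:00')]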