diff --git a/src/sdk/python/rtdip_sdk/pipelines/data_wranglers/spark/data_quality/interval_filtering.py b/src/sdk/python/rtdip_sdk/pipelines/data_wranglers/spark/data_quality/interval_filtering.py
index a7a768064..e0b353772 100644
--- a/src/sdk/python/rtdip_sdk/pipelines/data_wranglers/spark/data_quality/interval_filtering.py
+++ b/src/sdk/python/rtdip_sdk/pipelines/data_wranglers/spark/data_quality/interval_filtering.py
@@ -40,11 +40,12 @@ class IntervalFiltering(WranglerBaseInterface):
     DEFAULT_TIME_STAMP_COLUMN_NAME: str = "EventTime"
 
-    def __init__(self, spark: SparkSession, df: DataFrame, interval: int, interval_unit: str, time_stamp_column_name: str = None) -> None:
+    def __init__(self, spark: SparkSession, df: DataFrame, interval: int, interval_unit: str, time_stamp_column_name: str = None, tolerance: int = None) -> None:
         self.spark = spark
         self.df = df
         self.interval = interval
         self.interval_unit = interval_unit
+        self.tolerance = tolerance
         if time_stamp_column_name is None:
             self.time_stamp_column_name = self.DEFAULT_TIME_STAMP_COLUMN_NAME
         else:
             self.time_stamp_column_name = time_stamp_column_name
@@ -73,20 +74,26 @@ def convert_column_to_timestamp(self) -> DataFrame:
         except Exception as e:
             raise ValueError(f"Error converting column {self.time_stamp_column_name} to timestamp: {e}")
 
-    def get_time_delta(self) -> timedelta:
+    def get_time_delta(self, value: int) -> timedelta:
         if self.interval_unit == 'minutes':
-            return timedelta(minutes = self.interval)
+            return timedelta(minutes = value)
         elif self.interval_unit == 'days':
-            return timedelta(days = self.interval)
+            return timedelta(days = value)
         elif self.interval_unit == 'hours':
-            return timedelta(hours = self.interval)
+            return timedelta(hours = value)
         elif self.interval_unit == 'seconds':
-            return timedelta(seconds = self.interval)
+            return timedelta(seconds = value)
        elif self.interval_unit == 'milliseconds':
-            return timedelta(milliseconds = self.interval)
+            return timedelta(milliseconds = value)
         else:
             raise ValueError("interval_unit must be either 'days', 'hours', 'minutes', 'seconds' or 'milliseconds'")
 
+    def check_if_outside_of_interval(self, current_time_stamp: pd.Timestamp, last_time_stamp: pd.Timestamp, time_delta_in_ms: float, tolerance_in_ms: float) -> bool:
+        if tolerance_in_ms is None:
+            return ((current_time_stamp - last_time_stamp).total_seconds() * 1000) >= time_delta_in_ms
+        else:
+            return ((current_time_stamp - last_time_stamp).total_seconds() * 1000) + tolerance_in_ms >= time_delta_in_ms
+
     def format_date_time_to_string(self, time_stamp: pd.Timestamp) -> str:
         try:
             return time_stamp.strftime('%Y-%m-%d %H:%M:%S.%f')[:-3]
@@ -104,7 +111,13 @@ def filter(self) -> DataFrame:
         original_schema = self.df.schema
         self.df = self.convert_column_to_timestamp().orderBy(self.time_stamp_column_name)
 
-        time_delta_in_ms = self.get_time_delta().total_seconds() * 1000
+        tolerance_in_ms = None
+        if self.tolerance is not None:
+            tolerance_in_ms = self.get_time_delta(self.tolerance).total_seconds() * 1000
+            print(tolerance_in_ms)
+
+
+        time_delta_in_ms = self.get_time_delta(self.interval).total_seconds() * 1000
 
         rows = self.df.collect()
         last_time_stamp = rows[0][self.time_stamp_column_name]
@@ -117,7 +130,7 @@
             current_row = rows[i]
             current_time_stamp = current_row[self.time_stamp_column_name]
 
-            if ((current_time_stamp - last_time_stamp).total_seconds() * 1000) >= time_delta_in_ms:
+            if self.check_if_outside_of_interval(current_time_stamp, last_time_stamp, time_delta_in_ms, tolerance_in_ms):
                 current_row_dict = current_row.asDict()
                 current_row_dict[self.time_stamp_column_name] = self.format_date_time_to_string(current_row_dict[self.time_stamp_column_name])
                 cleansed_df.append(current_row_dict)
diff --git a/tests/sdk/python/rtdip_sdk/pipelines/data_wranglers/spark/data_quality/test_interval_filtering.py b/tests/sdk/python/rtdip_sdk/pipelines/data_wranglers/spark/data_quality/test_interval_filtering.py
index 1379d7929..8205dcad5 100644
--- a/tests/sdk/python/rtdip_sdk/pipelines/data_wranglers/spark/data_quality/test_interval_filtering.py
+++ b/tests/sdk/python/rtdip_sdk/pipelines/data_wranglers/spark/data_quality/test_interval_filtering.py
@@ -249,3 +249,34 @@ def test_interval_detection_faulty_time_stamp(spark_session: SparkSession):
 
     with pytest.raises(ValueError):
         interval_filtering_wrangler.filter()
+def test_interval_tolerance(spark_session: SparkSession):
+    expected_df = spark_session.createDataFrame(
+        [
+            ("A2PS64V0J.:ZUX09R", "2024-01-02 03:49:45.000", "Good", "0.129999995"),
+            ("A2PS64V0J.:ZUX09R", "2024-01-02 03:49:47.000", "Good", "0.129999995"),
+            ("A2PS64V0J.:ZUX09R", "2024-01-02 03:49:50.000", "Good", "0.129999995"),
+            ("A2PS64V0J.:ZUX09R", "2024-01-02 03:49:52.000", "Good", "0.129999995"),
+        ],
+        ["TagName", "EventTime", "Status", "Value"],
+    )
+
+    df = spark_session.createDataFrame(
+        [
+            ("A2PS64V0J.:ZUX09R", "2024-01-02 03:49:45.000", "Good", "0.129999995"),
+            ("A2PS64V0J.:ZUX09R", "2024-01-02 03:49:46.000", "Good", "0.129999995"),
+            ("A2PS64V0J.:ZUX09R", "2024-01-02 03:49:47.000", "Good", "0.129999995"),
+            ("A2PS64V0J.:ZUX09R", "2024-01-02 03:49:50.000", "Good", "0.129999995"),
+            ("A2PS64V0J.:ZUX09R", "2024-01-02 03:49:51.000", "Good", "0.129999995"),
+            ("A2PS64V0J.:ZUX09R", "2024-01-02 03:49:52.000", "Good", "0.129999995"),
+
+        ],
+        ["TagName", "EventTime", "Status", "Value"],
+    )
+
+    interval_filtering_wrangler = IntervalFiltering(spark_session, df, 3, "seconds", "EventTime", 1)
+    actual_df = interval_filtering_wrangler.filter()
+
+    assert expected_df.columns == actual_df.columns
+    assert expected_df.schema == actual_df.schema
+    assert expected_df.collect() == actual_df.collect()
+
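
A minimal usage sketch of the new tolerance parameter, based on the constructor signature and the added test above. The import path and the SparkSession setup are assumptions and may differ depending on how the package is installed; the sample rows mirror the test data.

from pyspark.sql import SparkSession

# Assumed import path; adjust to wherever IntervalFiltering is exposed in your install.
from rtdip_sdk.pipelines.data_wranglers.spark.data_quality.interval_filtering import (
    IntervalFiltering,
)

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame(
    [
        ("A2PS64V0J.:ZUX09R", "2024-01-02 03:49:45.000", "Good", "0.129999995"),
        ("A2PS64V0J.:ZUX09R", "2024-01-02 03:49:46.000", "Good", "0.129999995"),
        ("A2PS64V0J.:ZUX09R", "2024-01-02 03:49:47.000", "Good", "0.129999995"),
    ],
    ["TagName", "EventTime", "Status", "Value"],
)

# interval=3 "seconds" with tolerance=1: the 2 s gap from 03:49:45 to 03:49:47 plus
# the 1 s tolerance reaches the 3 s interval, so 03:49:47 is kept; 03:49:46 (1 s gap)
# falls inside the interval even with the tolerance and is dropped.
wrangler = IntervalFiltering(spark, df, 3, "seconds", "EventTime", tolerance=1)
filtered_df = wrangler.filter()
filtered_df.show(truncate=False)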