
Commit 2cd1d9b

Merge pull request #95 from amosproj/refactor/068_067_069

refactoring unit tests

dh1542 authored Dec 10, 2024
2 parents f1e99ac + 9fd3825
Showing 10 changed files with 1,497 additions and 465 deletions.
First changed file (CheckValueRanges):

@@ -15,6 +15,13 @@
import logging
from pyspark.sql import DataFrame as PySparkDataFrame
from pyspark.sql.functions import col
from pyspark.sql.types import (
StructType,
StructField,
StringType,
TimestampType,
FloatType,
)
from functools import reduce
from operator import or_

@@ -30,18 +37,17 @@

class CheckValueRanges(MonitoringBaseInterface, InputValidator):
"""
- Monitors data in a DataFrame by checking specified columns against expected value ranges.
- Logs events when values exceed the specified ranges.
+ Monitors data in a DataFrame by checking the 'Value' column against expected ranges for specified TagNames.
+ Logs events when 'Value' exceeds the defined ranges for any TagName.
Args:
df (pyspark.sql.DataFrame): The DataFrame to monitor.
- columns_ranges (dict): A dictionary where keys are column names and values are dictionaries specifying 'min' and/or
+ tag_ranges (dict): A dictionary where keys are TagNames and values are dictionaries specifying 'min' and/or
'max', and optionally 'inclusive_bounds' values.
Example:
{
- 'temperature': {'min': 0, 'max': 100, 'inclusive_bounds': True},
- 'pressure': {'min': 10, 'max': 200, 'inclusive_bounds': False},
- 'humidity': {'min': 30} # Defaults to inclusive_bounds = True
+ 'A2PS64V0J.:ZUX09R': {'min': 0, 'max': 100, 'inclusive_bounds': True},
+ 'B3TS64V0K.:ZUX09R': {'min': 10, 'max': 200, 'inclusive_bounds': False},
}
Example:
@@ -52,41 +58,50 @@ class CheckValueRanges(MonitoringBaseInterface, InputValidator):
spark = SparkSession.builder.master("local[1]").appName("CheckValueRangesExample").getOrCreate()
data = [
- (1, 25, 100),
- (2, -5, 150),
- (3, 50, 250),
- (4, 80, 300),
- (5, 100, 50),
+ ("A2PS64V0J.:ZUX09R", "2024-01-02 03:49:45.000", "Good", 25.0),
+ ("A2PS64V0J.:ZUX09R", "2024-01-02 07:53:11.000", "Good", -5.0),
+ ("A2PS64V0J.:ZUX09R", "2024-01-02 11:56:42.000", "Good", 50.0),
+ ("B3TS64V0K.:ZUX09R", "2024-01-02 16:00:12.000", "Good", 80.0),
+ ("A2PS64V0J.:ZUX09R", "2024-01-02 20:03:46.000", "Good", 100.0),
]
columns = ["ID", "temperature", "pressure"]
columns = ["TagName", "EventTime", "Status", "Value"]
df = spark.createDataFrame(data, columns)
- columns_ranges = {
- "temperature": {"min": 0, "max": 100, "inclusive_bounds": False},
- "pressure": {"min": 50, "max": 200},
+ tag_ranges = {
+ "A2PS64V0J.:ZUX09R": {"min": 0, "max": 50, "inclusive_bounds": True},
+ "B3TS64V0K.:ZUX09R": {"min": 50, "max": 100, "inclusive_bounds": False},
}
check_value_ranges = CheckValueRanges(
df=df,
- columns_ranges=columns_ranges,
+ tag_ranges=tag_ranges,
)
result_df = check_value_ranges.check()
```
"""

df: PySparkDataFrame
tag_ranges: dict
EXPECTED_SCHEMA = StructType(
[
StructField("TagName", StringType(), True),
StructField("EventTime", TimestampType(), True),
StructField("Status", StringType(), True),
StructField("Value", FloatType(), True),
]
)

def __init__(
self,
df: PySparkDataFrame,
- columns_ranges: dict,
+ tag_ranges: dict,
) -> None:

self.df = df
- self.columns_ranges = columns_ranges
+ self.validate(self.EXPECTED_SCHEMA)
+ self.tag_ranges = tag_ranges

# Configure logging
self.logger = logging.getLogger(self.__class__.__name__)
@@ -118,17 +133,22 @@ def settings() -> dict:

def check(self) -> PySparkDataFrame:
"""
- Executes the value range checking logic. Identifies and logs any rows where specified
- columns exceed their defined value ranges.
+ Executes the value range checking logic for the specified TagNames. Identifies and logs any rows
+ where 'Value' exceeds the defined ranges for each TagName.
Returns:
pyspark.sql.DataFrame:
Returns the original PySpark DataFrame without changes.
"""
self._validate_inputs()
df = self.df

- for column, range_dict in self.columns_ranges.items():
+ for tag_name, range_dict in self.tag_ranges.items():
+ df = self.df.filter(col("TagName") == tag_name)

+ if df.count() == 0:
+ self.logger.warning(f"No data found for TagName '{tag_name}'.")
+ continue

min_value = range_dict.get("min", None)
max_value = range_dict.get("max", None)
inclusive_bounds = range_dict.get("inclusive_bounds", True)
@@ -138,17 +158,17 @@ def check(self) -> PySparkDataFrame:
# Build minimum value condition
if min_value is not None:
if inclusive_bounds:
min_condition = col(column) < min_value
min_condition = col("Value") < min_value
else:
min_condition = col(column) <= min_value
min_condition = col("Value") <= min_value
conditions.append(min_condition)

# Build maximum value condition
if max_value is not None:
if inclusive_bounds:
max_condition = col(column) > max_value
max_condition = col("Value") > max_value
else:
max_condition = col(column) >= max_value
max_condition = col("Value") >= max_value
conditions.append(max_condition)

if not conditions:
@@ -160,31 +180,59 @@
count = out_of_range_df.count()
if count > 0:
self.logger.info(
f"Found {count} rows in column '{column}' out of range."
f"Found {count} rows in 'Value' column for TagName '{tag_name}' out of range."
)
out_of_range_rows = out_of_range_df.collect()
for row in out_of_range_rows:
self.logger.info(f"Out of range row in column '{column}': {row}")
self.logger.info(
f"Out of range row for TagName '{tag_name}': {row}"
)
else:
self.logger.info(f"No out of range values found in column '{column}'.")
self.logger.info(
f"No out of range values found in 'Value' column for TagName '{tag_name}'."
)

return self.df

def _validate_inputs(self):
- if not isinstance(self.columns_ranges, dict):
- raise TypeError("columns_ranges must be a dictionary.")
+ if not isinstance(self.tag_ranges, dict):
+ raise TypeError("tag_ranges must be a dictionary.")

# Build a list of all TagNames available in the DataFrame
available_tags = (
self.df.select("TagName").distinct().rdd.map(lambda row: row[0]).collect()
)

- for column, range_dict in self.columns_ranges.items():
- if column not in self.df.columns:
- raise ValueError(f"Column '{column}' not found in DataFrame.")
+ for tag_name, range_dict in self.tag_ranges.items():
+ # Check that the TagName is a valid string
+ if not isinstance(tag_name, str):
+ raise ValueError(f"TagName '{tag_name}' must be a string.")

+ # Check that the TagName exists in the DataFrame
+ if tag_name not in available_tags:
+ raise ValueError(f"TagName '{tag_name}' not found in DataFrame.")

# Check that at least 'min' or 'max' is specified
if "min" not in range_dict and "max" not in range_dict:
raise ValueError(
f"TagName '{tag_name}' must have at least 'min' or 'max' specified."
)

# Check that inclusive_bounds is a boolean
inclusive_bounds = range_dict.get("inclusive_bounds", True)
if not isinstance(inclusive_bounds, bool):
raise ValueError(
f"Inclusive_bounds for column '{column}' must be a boolean."
f"Inclusive_bounds for TagName '{tag_name}' must be a boolean."
)

if "min" not in range_dict and "max" not in range_dict:
# Optionale Überprüfung, ob min und max numerisch sind
min_value = range_dict.get("min", None)
max_value = range_dict.get("max", None)
if min_value is not None and not isinstance(min_value, (int, float)):
raise ValueError(
f"Minimum value for TagName '{tag_name}' must be a number."
)
if max_value is not None and not isinstance(max_value, (int, float)):
raise ValueError(
f"Column '{column}' must have at least 'min' or 'max' specified."
f"Maximum value for TagName '{tag_name}' must be a number."
)
Second changed file (FlatlineDetection):

@@ -14,8 +14,15 @@

import logging
from pyspark.sql import DataFrame as PySparkDataFrame
- from pyspark.sql.functions import col, when, lag, count, sum
+ from pyspark.sql.functions import col, when, lag, sum, abs
from pyspark.sql.window import Window
from pyspark.sql.types import (
StructType,
StructField,
StringType,
TimestampType,
FloatType,
)

from src.sdk.python.rtdip_sdk.pipelines.data_quality.monitoring.interfaces import (
MonitoringBaseInterface,
@@ -72,6 +79,14 @@ class FlatlineDetection(MonitoringBaseInterface, InputValidator):
df: PySparkDataFrame
watch_columns: list
tolerance_timespan: int
EXPECTED_SCHEMA = StructType(
[
StructField("TagName", StringType(), True),
StructField("EventTime", TimestampType(), True),
StructField("Status", StringType(), True),
StructField("Value", FloatType(), True),
]
)

def __init__(
self, df: PySparkDataFrame, watch_columns: list, tolerance_timespan: int
@@ -82,6 +97,7 @@ def __init__(
raise ValueError("tolerance_timespan must be a positive integer.")

self.df = df
self.validate(self.EXPECTED_SCHEMA)
self.watch_columns = watch_columns
self.tolerance_timespan = tolerance_timespan

@@ -113,58 +129,37 @@ def settings() -> dict:
return {}

def check(self) -> PySparkDataFrame:
"""
Detects flatlining in the specified columns and logs warnings if detected.
Returns:
pyspark.sql.DataFrame: The original PySpark DataFrame unchanged.
"""
- sort_column = self.df.columns[0]
+ partition_column = "TagName"
+ sort_column = "EventTime"
+ window_spec = Window.partitionBy(partition_column).orderBy(sort_column)

for column in self.watch_columns:
# Flag null or zero values
flagged_column = f"{column}_flatline_flag"
- flagged_df = self.df.withColumn(
+ self.df = self.df.withColumn(
flagged_column,
- when((col(column).isNull()) | (col(column) == 0), 1).otherwise(0),
+ when((col(column).isNull()) | (col(column) == 0.0), 1).otherwise(0),
)

# Create a group for consecutive flatline streaks
group_column = f"{column}_group"
- flagged_df = flagged_df.withColumn(
+ self.df = self.df.withColumn(
group_column,
- (
- col(flagged_column)
- != lag(col(flagged_column), 1, 0).over(Window.orderBy(sort_column))
- ).cast("int"),
- )
- flagged_df = flagged_df.withColumn(
- group_column, sum(col(group_column)).over(Window.orderBy(sort_column))
+ sum(
+ when(
+ col(flagged_column)
+ != lag(col(flagged_column), 1, 0).over(window_spec),
+ 1,
+ ).otherwise(0)
+ ).over(window_spec),
)

# Count rows in each group
group_counts = (
- flagged_df.filter(col(flagged_column) == 1)
- .groupBy(group_column)
- .count()
+ self.df.filter(col(flagged_column) == 1).groupBy(group_column).count()
)

# Filter groups that exceed the tolerance
large_groups = group_counts.filter(col("count") > self.tolerance_timespan)

# Log all rows in groups exceeding tolerance
if large_groups.count() > 0:
large_group_ids = [row[group_column] for row in large_groups.collect()]
relevant_rows = (
flagged_df.filter(col(group_column).isin(large_group_ids))
.select(*self.df.columns)
.collect()
)
for row in relevant_rows:
large_group_ids = [row[group_column] for row in large_groups.collect()]
if large_group_ids:
relevant_rows = self.df.filter(col(group_column).isin(large_group_ids))
for row in relevant_rows.collect():
self.logger.warning(
f"Flatlining detected in column '{column}' at row: {row}."
)
else:
self.logger.info(f"No flatlining detected in column '{column}'.")

return self.df
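Likewise, a minimal sketch for the refactored FlatlineDetection (not part of the diff), using the constructor shown above (df, watch_columns, tolerance_timespan); the import path is again assumed. Values that are null or 0.0 in a watched column for more than tolerance_timespan consecutive rows per TagName are logged as flatlining.

```python
from datetime import datetime, timedelta

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, TimestampType, FloatType

# from <module shown in this diff> import FlatlineDetection  # import path assumed, not shown above

spark = SparkSession.builder.master("local[1]").appName("FlatlineDetectionSketch").getOrCreate()

schema = StructType(
    [
        StructField("TagName", StringType(), True),
        StructField("EventTime", TimestampType(), True),
        StructField("Status", StringType(), True),
        StructField("Value", FloatType(), True),
    ]
)
start = datetime(2024, 1, 2, 0, 0, 0)
# Three consecutive 0.0 readings for one tag exceed a tolerance_timespan of 2
data = [
    ("A2PS64V0J.:ZUX09R", start + timedelta(minutes=i), "Good", value)
    for i, value in enumerate([1.0, 0.0, 0.0, 0.0, 5.0])
]
df = spark.createDataFrame(data, schema)

flatline_detection = FlatlineDetection(df, watch_columns=["Value"], tolerance_timespan=2)
result_df = flatline_detection.check()  # warns about the flatlined rows, returns the DataFrame unchanged
```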
(Diffs for the remaining changed files are not shown here.)
