
Commit

#22: Fixed logic and unit mismatches
Signed-off-by: Dominik Hoffmann <[email protected]>
dh1542 committed Nov 9, 2024
1 parent 396b49f commit 9ec9b45
Showing 1 changed file with 18 additions and 22 deletions.
@@ -14,8 +14,10 @@
 from datetime import timedelta
 
 import pandas as pd
+from databricks.sqlalchemy.test_local.conftest import schema
 from pyspark.sql import functions as F
-from pyspark.sql import DataFrame, SparkSession
+from pyspark.sql import SparkSession
+from pyspark.sql import DataFrame
 
 from ...._pipeline_utils.models import Libraries, SystemType
 from ...interfaces import WranglerBaseInterface
@@ -36,6 +38,7 @@ class IntervalFiltering(WranglerBaseInterface):
     """ Default time stamp column name if not set in the constructor """
     DEFAULT_TIME_STAMP_COLUMN_NAME: str = "EventTime"
 
+
     def __init__(self, spark: SparkSession, df: DataFrame, interval: int, interval_unit: str, time_stamp_column_name: str = None) -> None:
         self.spark = spark
         self.df = df
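For illustration only, not part of the commit: a hypothetical call site matching the constructor signature above, assuming a live SparkSession named spark and a DataFrame df whose timestamps live in the default "EventTime" column.

    # Hypothetical usage (names `spark` and `df` are assumed, not from the diff):
    interval_filtering = IntervalFiltering(
        spark, df, interval=5, interval_unit="seconds"
    )  # time_stamp_column_name falls back to DEFAULT_TIME_STAMP_COLUMN_NAME
    filtered_df = interval_filtering.filter()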
@@ -83,6 +86,9 @@ def get_time_delta(self) -> timedelta:
         else:
             raise ValueError("interval_unit must be either 'seconds' or 'milliseconds'")
 
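Only the final branch of get_time_delta is visible in this hunk. A plausible reconstruction of the whole method, inferred from the error message above and the constructor's interval/interval_unit parameters; the diff does not confirm this body:

    def get_time_delta(self) -> timedelta:
        # Inferred sketch -- only the else branch below appears in the diff.
        if self.interval_unit == "seconds":
            return timedelta(seconds=self.interval)
        elif self.interval_unit == "milliseconds":
            return timedelta(milliseconds=self.interval)
        else:
            raise ValueError("interval_unit must be either 'seconds' or 'milliseconds'")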
+    def format_date_time_to_string(self, time_stamp: pd.Timestamp) -> str:
+        return time_stamp.strftime('%Y-%m-%d %H:%M:%S.%f')[:-3]
+
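A quick standalone check of the new helper's formatting, not part of the commit: %f prints six digits (microseconds), and the [:-3] slice truncates the string to milliseconds.

    import pandas as pd

    ts = pd.Timestamp("2024-11-09 12:30:45.123456")
    print(ts.strftime('%Y-%m-%d %H:%M:%S.%f'))       # 2024-11-09 12:30:45.123456
    print(ts.strftime('%Y-%m-%d %H:%M:%S.%f')[:-3])  # 2024-11-09 12:30:45.123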
     def filter(self) -> DataFrame:
         """
         Filters the DataFrame based on the interval
@@ -91,40 +97,30 @@ def filter(self) -> DataFrame:
         if self.time_stamp_column_name not in self.df.columns:
             raise ValueError(f"Column {self.time_stamp_column_name} not found in the DataFrame.")
 
-        self.df = self.convert_column_to_timestamp()
+        original_schema = self.df.schema
+        self.df = self.convert_column_to_timestamp().orderBy(self.time_stamp_column_name)
 
-        time_delta = self.get_time_delta()
+        time_delta_in_ms = self.get_time_delta().total_seconds() * 1000
 
         rows = self.df.collect()
-        cleansed_df = [rows[0]]
-
-
         last_time_stamp = rows[0][self.time_stamp_column_name]
+        first_row = rows[0].asDict()
+        first_row[self.time_stamp_column_name] = self.format_date_time_to_string(first_row[self.time_stamp_column_name])
+
+        cleansed_df = [first_row]
 
         for i in range(1, len(rows)):
             current_row = rows[i]
             current_time_stamp = current_row[self.time_stamp_column_name]
-            if ((last_time_stamp - current_time_stamp).total_seconds()) >= time_delta.total_seconds():
-
-
-                cleansed_df.append(current_row)
+            if ((current_time_stamp - last_time_stamp).total_seconds() * 1000) >= time_delta_in_ms:
+                current_row_dict = current_row.asDict()
+                current_row_dict[self.time_stamp_column_name] = self.format_date_time_to_string(current_row_dict[self.time_stamp_column_name])
+                cleansed_df.append(current_row_dict)
                 last_time_stamp = current_time_stamp
 
-        # Create Dataframe from cleansed data
-        result_df = pd.DataFrame(cleansed_df)
-
-        # rename the columns back to original
-        column_names = self.df.columns
-        result_df.columns = column_names
-
-        # Convert Dataframe time_stamp column back to string
-        result_df[self.time_stamp_column_name] = result_df[self.time_stamp_column_name].dt.strftime('%Y-%m-%d %H:%M:%S.%f')
-
-
-        print(result_df)
+        result_df = self.spark.createDataFrame(cleansed_df, schema= original_schema)
+
         return result_df
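The core of the fix: the old check subtracted timestamps in the wrong order, so on an ascending column the difference was negative and no later rows were ever kept, and the new check compares elapsed time since the last kept row in milliseconds on both sides. A self-contained sketch of the corrected comparison, not part of the commit, with sample timestamps assumed and no Spark required:

    from datetime import timedelta
    import pandas as pd

    # Assumed sample data: one reading per second, filtered to a 2-second interval.
    stamps = pd.to_datetime([
        "2024-01-01 00:00:00.000",
        "2024-01-01 00:00:01.000",
        "2024-01-01 00:00:02.000",
        "2024-01-01 00:00:03.000",
    ])
    time_delta_in_ms = timedelta(seconds=2).total_seconds() * 1000

    kept = [stamps[0]]
    last_time_stamp = stamps[0]
    for current_time_stamp in stamps[1:]:
        # Corrected check: elapsed milliseconds since the last kept row.
        # (The old `last - current` order is negative here and would keep nothing.)
        if (current_time_stamp - last_time_stamp).total_seconds() * 1000 >= time_delta_in_ms:
            kept.append(current_time_stamp)
            last_time_stamp = current_time_stamp

    print(kept)  # 00:00:00 and 00:00:02 survive the 2-second interval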


