Skip to content

Commit

Permalink
comments ruben
Browse files Browse the repository at this point in the history
  • Loading branch information
abontsema committed Aug 18, 2022
1 parent 1d03c8d commit 9930404
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 35 deletions.
4 changes: 2 additions & 2 deletions sam/validation/base_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def fit(self, X, y=None):
def transform(self, X: pd.DataFrame) -> pd.DataFrame:
"""transform method"""
X = X.copy()
invalids = self.validate(X)
X[invalids] = np.nan
invalid_data = self.validate(X)
X[invalid_data] = np.nan
return X

def get_feature_names_out(self, input_features=None) -> List[str]:
Expand Down
14 changes: 6 additions & 8 deletions sam/validation/flatline_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
import pandas as pd
from sam.utils import add_future_warning
from sam.validation import BaseValidator

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -138,29 +139,26 @@ def validate(self, X: pd.DataFrame) -> pd.DataFrame:
X: pd.DataFrame
Input dataframe to validate
"""
invalids = pd.DataFrame(
invalid_data = pd.DataFrame(
data=np.zeros_like(X.values).astype(bool),
index=X.index,
columns=X.columns,
)

for col in self.cols:
window = self.window_dict[col]
invalids[col] = self._validate_column(X[col], window)
invalid_data[col] = self._validate_column(X[col], window)

logger.info(
f"detected {np.sum(invalids[col])} "
f"detected {np.sum(invalid_data[col])} "
f"flatline samples in {col} "
f"with window of {window} "
)

return invalids
return invalid_data


class RemoveFlatlines(FlatlineValidator):
@add_future_warning("RemoveFlatlines is deprecated, use FlatlineValidator instead")
def __init__(self, *args, **kwargs):
warnings.warn(
"RemoveFlatlines is deprecated, use FlatlineValidator instead",
DeprecationWarning,
)
super().__init__(*args, **kwargs)
23 changes: 10 additions & 13 deletions sam/validation/mad_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
import pandas as pd
from sam.validation import BaseValidator
from sam.utils import add_future_warning

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -141,7 +142,7 @@ def validate(self, X: pd.DataFrame):
input data with columns marked as nan
"""

invalids = pd.DataFrame(
invalid_data = pd.DataFrame(
data=np.zeros_like(X.values).astype(bool),
index=X.index,
columns=X.columns,
Expand All @@ -161,23 +162,19 @@ def validate(self, X: pd.DataFrame):

# log number of values removed and thresholds used
logger.info(
"detected %d " % np.sum(extreme_value)
+ "extreme values from %s. " % c
+ "using upper threshold of: %.2f " % self.thresh_high[c]
+ "and lower threshold of: %.2f " % self.thresh_low[c]
+ "using madthresh of %d " % self.madthresh
+ "and rollingwindow of %s" % str(self.rollingwindow)
f"detected {np.sum(extreme_value)} extreme values from {c}. "
f"using upper threshold of: {round(self.thresh_high[c], 2)} "
f"and lower threshold of: {round(self.thresh_low[c])} "
f"using madthresh of {self.madthresh} "
f"and rollingwindow of {str(self.rollingwindow)}"
)

invalids[c] = extreme_value
invalid_data[c] = extreme_value

return invalids
return invalid_data


class RemoveExtremeValues(MADValidator):
@add_future_warning("RemoveExtremeValues is deprecated. Use MADValidator instead.")
def __init__(self, *args, **kwargs):
warnings.warn(
"RemoveExtremeValues is deprecated. Use MADValidator instead.",
DeprecationWarning,
)
super().__init__(*args, **kwargs)
8 changes: 5 additions & 3 deletions sam/validation/outside_range_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,14 @@ def validate(self, X):
Dataframe containing the features to be checked.
"""

invalids = pd.DataFrame(
invalid_data = pd.DataFrame(
data=np.zeros_like(X.values).astype(bool),
index=X.index,
columns=X.columns,
)

invalids[self.cols] = X[self.cols].gt(self.max_value_) | X[self.cols].lt(self.min_value_)
invalid_data[self.cols] = X[self.cols].gt(self.max_value_) | X[self.cols].lt(
self.min_value_
)

return invalids
return invalid_data
6 changes: 3 additions & 3 deletions sam/visualization/diagnostic_flatline_removal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@


def diagnostic_flatline_removal(
fv: FlatlineValidator,
flatline_validator: FlatlineValidator,
raw_data: pd.DataFrame,
col: str,
):
Expand All @@ -12,7 +12,7 @@ def diagnostic_flatline_removal(
Parameters:
----------
fv: sam.validation.FlatlineValidator
flatline_validator: sam.validation.FlatlineValidator
fitted FlatlineValidator object
raw_data: pd.DataFrame
non-transformed data
Expand All @@ -28,7 +28,7 @@ def diagnostic_flatline_removal(

# get data
x = raw_data[col].copy()
invalid_w = fv.validate(raw_data)[col]
invalid_w = flatline_validator.validate(raw_data)[col]
invalid_values = x[invalid_w]

# generate plot
Expand Down
12 changes: 6 additions & 6 deletions sam/visualization/extreme_removal_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@


def diagnostic_extreme_removal(
madv: MADValidator,
mad_validator: MADValidator,
raw_data: pd.DataFrame,
col: str,
):
Expand All @@ -12,7 +12,7 @@ def diagnostic_extreme_removal(
Parameters:
----------
rev: sam.validation.MADValidator
mad_validator: sam.validation.MADValidator
fitted MADValidator object
raw_data: pd.DataFrame
non-transformed data
Expand All @@ -29,9 +29,9 @@ def diagnostic_extreme_removal(

# get data
x = raw_data[col].copy()
invalid_w = madv.validate(raw_data)[col]
invalid_w = mad_validator.validate(raw_data)[col]
invalid_values = x.loc[invalid_w]
rolling = madv._compute_rolling(x)
rolling = mad_validator._compute_rolling(x)
diff = x.values - rolling

# generate plot
Expand Down Expand Up @@ -65,8 +65,8 @@ def diagnostic_extreme_removal(

plt.subplot(212)
plt.plot(diff.values, label="abs(original - rolling)")
plt.axhline(madv.thresh_high[col], ls="--", c="r")
plt.axhline(madv.thresh_low[col], ls="--", c="r", label="thresholds")
plt.axhline(mad_validator.thresh_high[col], ls="--", c="r")
plt.axhline(mad_validator.thresh_low[col], ls="--", c="r", label="thresholds")
plt.legend(loc="best")
sns.despine()
plt.tight_layout()
Expand Down

0 comments on commit 9930404

Please sign in to comment.