Skip to content

Commit

Permalink
add pytest 'backwards' function
Browse files Browse the repository at this point in the history
  • Loading branch information
giovanni.buroni committed Dec 1, 2024
1 parent a921003 commit 01f296d
Showing 1 changed file with 20 additions and 1 deletion.
21 changes: 20 additions & 1 deletion tests/test_utils_memory.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
import pandas as pd
from source.utils.utils_memory import update_period2farm_and_farm2period_train, update_period2farm_and_farm2period_test
from source.utils.utils_memory import update_period2farm_and_farm2period_train, update_period2farm_and_farm2period_test, backwards

# Test case: valid DataFrame
def test_update_period2farm_and_farm2period_train_output(sample_df):
Expand Down Expand Up @@ -57,3 +57,22 @@ def test_update_period2farm_and_farm2period_test_missing_columns(sample_df_missi
# Call the function with a DataFrame missing required columns
with pytest.raises(AssertionError, match="DataFrame df must contain columns 'periodId', 'farmId', and 'power_z'."):
update_period2farm_and_farm2period_test(sample_df_missing_columns)

# Test case: DataFrame output contain lags
def test_backwards_output(sample_df_back_lags):
" Test if the lag features are created correctly. "
# Call the function with a DataFrame
df_lags = backwards(sample_df_back_lags, ['A', 'B'], [1, 2], 2)
# Assert that the lag features are created correctly
print(df_lags.columns.tolist())
assert df_lags.columns.tolist() == ['A', 'B', 'A_lag_-1', 'A_lag_-2', 'B_lag_-1', 'B_lag_-2'], "Lag features are not created correctly."

# Test case: check that no missing values are present
def test_backwards_no_missing_values(sample_df_back_lags):
" Test if the lag features do not contain missing values. "
# Call the function with a DataFrame
df_lags = backwards(sample_df_back_lags, ['A'], [1, 2], 2)
# Assert that there are no missing values in the lag features
print(df_lags)
print(df_lags.isnull().sum())
assert df_lags.isnull().sum().sum() == 0, "Lag features should not contain missing values."

0 comments on commit 01f296d

Please sign in to comment.