diff --git a/src/flag_and_count_matched_pairs.py b/src/flag_and_count_matched_pairs.py index 9fbc27db..9ab4a480 100644 --- a/src/flag_and_count_matched_pairs.py +++ b/src/flag_and_count_matched_pairs.py @@ -1,7 +1,10 @@ -import pandas as pd -import numpy as np +import numpy as np # noqa F401 +import pandas as pd # noqa F401 -def flag_matched_pair(df, forward_or_backward, target, period, reference, strata, time_difference=1): + +def flag_matched_pair( + df, forward_or_backward, target, period, reference, strata, time_difference=1 +): """ function to flag matched pairs using the shift method @@ -25,20 +28,26 @@ def flag_matched_pair(df, forward_or_backward, target, period, reference, strata Returns ------- _type_ - two pandas dataframes: the main dataframe with column added flagging forward matched pairs and + two pandas dataframes: the main dataframe with column added flagging + forward matched pairs and predictive target variable data column - """ - - df = df.sort_values(by = [reference, period]) + """ + + df = df.sort_values(by=[reference, period]) - if forward_or_backward == 'b': + if forward_or_backward == "b": time_difference = -time_difference - - df[forward_or_backward+"_match"] = df.groupby([strata, reference]).shift(time_difference)[target].notnull().mul(df[target].notnull()) - + + df[forward_or_backward + "_match"] = ( + df.groupby([strata, reference]) + .shift(time_difference)[target] + .notnull() + .mul(df[target].notnull()) + ) + return df - - + + def count_matches(df, flag, period, strata): """ function to flag matched pairs using the shift method @@ -48,7 +57,8 @@ def count_matches(df, flag, period, strata): df : pd.DataFrame pandas dataframe of original data with imputation flags flag : str/list - the imputation flag column/s. Single string if one column, list of strings for multiple columns. + the imputation flag column/s. Single string if one column, list of + strings for multiple columns. period : str column name containing time period strata : str @@ -58,9 +68,6 @@ def count_matches(df, flag, period, strata): ------- _type_ pandas dataframe: match counts for each flag column. - """ - + """ + return df.groupby([strata, period])[flag].agg("sum").reset_index() - - - \ No newline at end of file diff --git a/tests/test_data_matched_pair/count_matches_expected_output.csv b/tests/test_data_matched_pair/count_matches_expected_output.csv index 54b7e5de..01e74b61 100755 --- a/tests/test_data_matched_pair/count_matches_expected_output.csv +++ b/tests/test_data_matched_pair/count_matches_expected_output.csv @@ -2,4 +2,4 @@ group,period,flag_1,flag_2 1,202401,1,0 1,202402,0,2 2,202401,2,1 -2,202402,1,1 \ No newline at end of file +2,202402,1,1 diff --git a/tests/test_data_matched_pair/count_matches_input.csv b/tests/test_data_matched_pair/count_matches_input.csv index c1743dbf..1b4d26ab 100755 --- a/tests/test_data_matched_pair/count_matches_input.csv +++ b/tests/test_data_matched_pair/count_matches_input.csv @@ -6,4 +6,4 @@ period,group,flag_1,flag_2 202401,2,1,TRUE 202401,2,1,FALSE 202402,2,0,FALSE -202402,2,1,TRUE \ No newline at end of file +202402,2,1,TRUE diff --git a/tests/test_data_matched_pair/flag_pairs_expected_output.csv b/tests/test_data_matched_pair/flag_pairs_expected_output.csv index 5a03317d..a7aeb212 100644 --- a/tests/test_data_matched_pair/flag_pairs_expected_output.csv +++ b/tests/test_data_matched_pair/flag_pairs_expected_output.csv @@ -6,4 +6,4 @@ reference,strata,period,target_variable,f_match,b_match 2,101,202402,250,True,True 2,101,202403,255,True,False 2,102,202404,260,False,True -2,102,202405,272,True,False \ No newline at end of file +2,102,202405,272,True,False diff --git a/tests/test_flag_and_count_matched_pairs.py b/tests/test_flag_and_count_matched_pairs.py index 312c5800..b353e11b 100644 --- a/tests/test_flag_and_count_matched_pairs.py +++ b/tests/test_flag_and_count_matched_pairs.py @@ -1,38 +1,53 @@ -import pandas as pd -import pytest - -from pandas.testing import assert_frame_equal from pathlib import Path -from src.flag_and_count_matched_pairs import flag_matched_pair, count_matches +import pandas as pd # noqa F401 +import pytest from helper_functions import load_and_format +from pandas.testing import assert_frame_equal + +from src.flag_and_count_matched_pairs import count_matches, flag_matched_pair + @pytest.fixture(scope="class") def match_test_data(): - return load_and_format(Path('tests')/'test_data_matched_pair/flag_pairs_expected_output.csv') + return load_and_format( + Path("tests") / "test_data_matched_pair/flag_pairs_expected_output.csv" + ) + @pytest.fixture(scope="class") def count_test_data(): - return load_and_format(Path('tests')/'test_data_matched_pair/count_matches_input.csv') - + return load_and_format( + Path("tests") / "test_data_matched_pair/count_matches_input.csv" + ) + + @pytest.fixture(scope="class") def count_expected_output(): - return load_and_format(Path('tests')/'test_data_matched_pair/count_matches_expected_output.csv') + return load_and_format( + Path("tests") / "test_data_matched_pair/count_matches_expected_output.csv" + ) + class TestMatchedPair: - def test_flag_matched_pair_forward(self, match_test_data): - expected_output = match_test_data.drop(['b_match'],axis = 1) - df_input = match_test_data[['reference', 'strata', 'period', 'target_variable']] - df_output = flag_matched_pair(df_input,'f','target_variable','period', 'reference', 'strata') - assert_frame_equal(df_output, expected_output) - - def test_flag_matched_pair_backward(self, match_test_data): - expected_output = match_test_data.drop(['f_match'],axis = 1) - df_input = match_test_data[['reference', 'strata', 'period', 'target_variable']] - df_output = flag_matched_pair(df_input,'b','target_variable','period', 'reference', 'strata') - assert_frame_equal(df_output, expected_output) - + def test_flag_matched_pair_forward(self, match_test_data): + expected_output = match_test_data.drop(["b_match"], axis=1) + df_input = match_test_data[["reference", "strata", "period", "target_variable"]] + df_output = flag_matched_pair( + df_input, "f", "target_variable", "period", "reference", "strata" + ) + assert_frame_equal(df_output, expected_output) + + def test_flag_matched_pair_backward(self, match_test_data): + expected_output = match_test_data.drop(["f_match"], axis=1) + df_input = match_test_data[["reference", "strata", "period", "target_variable"]] + df_output = flag_matched_pair( + df_input, "b", "target_variable", "period", "reference", "strata" + ) + assert_frame_equal(df_output, expected_output) + + class TestCountMatches: - def test_count_matches(self, count_test_data, count_expected_output): - output = count_matches(count_test_data, ["flag_1", "flag_2"], "period", "group") - assert_frame_equal(output, count_expected_output) \ No newline at end of file + def test_count_matches(self, count_test_data, count_expected_output): + output = count_matches(count_test_data, ["flag_1", "flag_2"], "period", "group") + assert_frame_equal(output, count_expected_output)