Skip to content

Commit

Permalink
ran pre-commits
Browse files Browse the repository at this point in the history
  • Loading branch information
robertswh committed Jun 10, 2024
1 parent 47f02dd commit ad64d1a
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 46 deletions.
45 changes: 26 additions & 19 deletions src/flag_and_count_matched_pairs.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import pandas as pd
import numpy as np
import numpy as np # noqa F401
import pandas as pd # noqa F401

def flag_matched_pair(df, forward_or_backward, target, period, reference, strata, time_difference=1):

def flag_matched_pair(
df, forward_or_backward, target, period, reference, strata, time_difference=1
):
"""
function to flag matched pairs using the shift method
Expand All @@ -25,20 +28,26 @@ def flag_matched_pair(df, forward_or_backward, target, period, reference, strata
Returns
-------
_type_
two pandas dataframes: the main dataframe with column added flagging forward matched pairs and
two pandas dataframes: the main dataframe with column added flagging
forward matched pairs and
predictive target variable data column
"""
df = df.sort_values(by = [reference, period])
"""

df = df.sort_values(by=[reference, period])

if forward_or_backward == 'b':
if forward_or_backward == "b":
time_difference = -time_difference

df[forward_or_backward+"_match"] = df.groupby([strata, reference]).shift(time_difference)[target].notnull().mul(df[target].notnull())


df[forward_or_backward + "_match"] = (
df.groupby([strata, reference])
.shift(time_difference)[target]
.notnull()
.mul(df[target].notnull())
)

return df


def count_matches(df, flag, period, strata):
"""
function to count matched pairs by summing the flag column/s within each strata and period
Expand All @@ -48,7 +57,8 @@ def count_matches(df, flag, period, strata):
df : pd.DataFrame
pandas dataframe of original data with imputation flags
flag : str/list
the imputation flag column/s. Single string if one column, list of strings for multiple columns.
the imputation flag column/s. Single string if one column, list of
strings for multiple columns.
period : str
column name containing time period
strata : str
Expand All @@ -58,9 +68,6 @@ def count_matches(df, flag, period, strata):
-------
_type_
pandas dataframe: match counts for each flag column.
"""
"""

return df.groupby([strata, period])[flag].agg("sum").reset_index()



Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ group,period,flag_1,flag_2
1,202401,1,0
1,202402,0,2
2,202401,2,1
2,202402,1,1
2,202402,1,1
2 changes: 1 addition & 1 deletion tests/test_data_matched_pair/count_matches_input.csv
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ period,group,flag_1,flag_2
202401,2,1,TRUE
202401,2,1,FALSE
202402,2,0,FALSE
202402,2,1,TRUE
202402,2,1,TRUE
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ reference,strata,period,target_variable,f_match,b_match
2,101,202402,250,True,True
2,101,202403,255,True,False
2,102,202404,260,False,True
2,102,202405,272,True,False
2,102,202405,272,True,False
63 changes: 39 additions & 24 deletions tests/test_flag_and_count_matched_pairs.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,53 @@
import pandas as pd
import pytest

from pandas.testing import assert_frame_equal
from pathlib import Path

from src.flag_and_count_matched_pairs import flag_matched_pair, count_matches
import pandas as pd # noqa F401
import pytest
from helper_functions import load_and_format
from pandas.testing import assert_frame_equal

from src.flag_and_count_matched_pairs import count_matches, flag_matched_pair


@pytest.fixture(scope="class")
def match_test_data():
return load_and_format(Path('tests')/'test_data_matched_pair/flag_pairs_expected_output.csv')
return load_and_format(
Path("tests") / "test_data_matched_pair/flag_pairs_expected_output.csv"
)


@pytest.fixture(scope="class")
def count_test_data():
return load_and_format(Path('tests')/'test_data_matched_pair/count_matches_input.csv')

return load_and_format(
Path("tests") / "test_data_matched_pair/count_matches_input.csv"
)


@pytest.fixture(scope="class")
def count_expected_output():
return load_and_format(Path('tests')/'test_data_matched_pair/count_matches_expected_output.csv')
return load_and_format(
Path("tests") / "test_data_matched_pair/count_matches_expected_output.csv"
)


class TestMatchedPair:
def test_flag_matched_pair_forward(self, match_test_data):
expected_output = match_test_data.drop(['b_match'],axis = 1)
df_input = match_test_data[['reference', 'strata', 'period', 'target_variable']]
df_output = flag_matched_pair(df_input,'f','target_variable','period', 'reference', 'strata')
assert_frame_equal(df_output, expected_output)

def test_flag_matched_pair_backward(self, match_test_data):
expected_output = match_test_data.drop(['f_match'],axis = 1)
df_input = match_test_data[['reference', 'strata', 'period', 'target_variable']]
df_output = flag_matched_pair(df_input,'b','target_variable','period', 'reference', 'strata')
assert_frame_equal(df_output, expected_output)

def test_flag_matched_pair_forward(self, match_test_data):
expected_output = match_test_data.drop(["b_match"], axis=1)
df_input = match_test_data[["reference", "strata", "period", "target_variable"]]
df_output = flag_matched_pair(
df_input, "f", "target_variable", "period", "reference", "strata"
)
assert_frame_equal(df_output, expected_output)

def test_flag_matched_pair_backward(self, match_test_data):
expected_output = match_test_data.drop(["f_match"], axis=1)
df_input = match_test_data[["reference", "strata", "period", "target_variable"]]
df_output = flag_matched_pair(
df_input, "b", "target_variable", "period", "reference", "strata"
)
assert_frame_equal(df_output, expected_output)


class TestCountMatches:
def test_count_matches(self, count_test_data, count_expected_output):
output = count_matches(count_test_data, ["flag_1", "flag_2"], "period", "group")
assert_frame_equal(output, count_expected_output)
def test_count_matches(self, count_test_data, count_expected_output):
output = count_matches(count_test_data, ["flag_1", "flag_2"], "period", "group")
assert_frame_equal(output, count_expected_output)

0 comments on commit ad64d1a

Please sign in to comment.