diff --git a/.gitignore b/.gitignore
index 13c4a8c..f036118 100644
--- a/.gitignore
+++ b/.gitignore
@@ -167,3 +167,9 @@ logs/
 # pyright config
 pyrightconfig.json
+
+# scratch
+notebooks/scratch*
+
+# AcBM config
+config/
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 227e436..c1a3dc4 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -23,7 +23,7 @@ repos:
       - id: trailing-whitespace
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: "v0.2.0"
+    rev: "v0.7.0"
    hooks:
      # first, lint + autofix
      - id: ruff
diff --git a/config/base.toml b/config/base.toml
index 38a4f2f..9829be3 100644
--- a/config/base.toml
+++ b/config/base.toml
@@ -1,11 +1,21 @@
 [parameters]
 seed = 0
 region = "leeds"
-number_of_households = 10000
+number_of_households = 5000
 zone_id = "OA21CD"
-travel_times = true # Only set to true if you have travel time matrix at the level specified in boundary_geography
+travel_times = true # Only set to true if you have travel time matrix at the level specified in boundary_geography
 boundary_geography = "OA"

+[matching]
+required_columns = ["number_adults", "number_children"]
+optional_columns = [
+    "number_cars",
+    "num_pension_age",
+    "rural_urban_2_categories",
+    "employment_status",
+    "tenure_status",
+]
+n_matches = 10

 [work_assignment]
 use_percentages = true
diff --git a/config/base_500.toml b/config/base_500.toml
deleted file mode 100644
index d9164a4..0000000
--- a/config/base_500.toml
+++ /dev/null
@@ -1,14 +0,0 @@
-[parameters]
-seed = 0
-region = "leeds"
-number_of_households = 500
-zone_id = "OA21CD"
-travel_times = true # Only set to true if you have travel time matrix at the level specified in boundary_geography
-boundary_geography = "OA"
-
-
-[work_assignment]
-use_percentages = true
-weight_max_dev = 0.2
-weight_total_dev = 0.8
-max_zones = 8
diff --git a/config/base_5000.toml b/config/base_5000.toml
deleted file mode 100644
index d4f0134..0000000
--- a/config/base_5000.toml
+++ /dev/null
@@ -1,13 +0,0 @@
-[parameters]
-seed = 0
-region = "leeds"
-number_of_households = 5000
-zone_id = "OA21CD"
-travel_times = true # Only set to true if you have travel time matrix at the level specified in boundary_geography
-boundary_geography = "OA"
-
-[work_assignment]
-use_percentages = true
-weight_max_dev = 0.2
-weight_total_dev = 0.8
-max_zones = 8
diff --git a/config/base_all.toml b/config/base_all.toml
deleted file mode 100644
index bb1cc1e..0000000
--- a/config/base_all.toml
+++ /dev/null
@@ -1,13 +0,0 @@
-[parameters]
-seed = 0
-region = "leeds"
-zone_id = "OA21CD"
-travel_times = true # Only set to true if you have travel time matrix at the level specified in boundary_geography
-boundary_geography = "OA"
-
-
-[work_assignment]
-use_percentages = false
-weight_max_dev = 0.0
-weight_total_dev = 1.0
-max_zones = 4
diff --git a/scripts/1_prep_synthpop.py b/scripts/1_prep_synthpop.py
index 4d3d0e1..3df2dd8 100644
--- a/scripts/1_prep_synthpop.py
+++ b/scripts/1_prep_synthpop.py
@@ -27,35 +27,6 @@ def main(config_file):
         acbm.root_path / f"data/external/spc_output/{region}_people_hh.parquet"
     )

-    # People and time-use data
-    # Subset of (non-time-use) features to include and unnest
-    # The features can be found here: https://github.com/alan-turing-institute/uatk-spc/blob/main/synthpop.proto
-    features = {
-        "health": [
-            "bmi",
-            "has_cardiovascular_disease",
-            "has_diabetes",
-            "has_high_blood_pressure",
-            "self_assessed_health",
-            "life_satisfaction",
-        ],
-        "demographics": ["age_years", "ethnicity", "sex", "nssec8"],
-        "employment": ["sic1d2007", "sic2d2007", "pwkstat", "salary_yearly"],
-    }
-
-    # build the table
-    spc_people_tu = (
-        Builder(path, region, backend="polars", input_type="parquet")
-        .add_households()
-        .add_time_use_diaries(features, diary_type="weekday_diaries")
-        .build()
-    )
-
-    # save the output
-    spc_people_tu.write_parquet(
-        acbm.root_path / f"data/external/spc_output/{region}_people_tu.parquet"
-    )
-

 if __name__ == "__main__":
     main()
diff --git a/scripts/2_match_households_and_individuals.py b/scripts/2_match_households_and_individuals.py
index aaf766b..982e97c 100644
--- a/scripts/2_match_households_and_individuals.py
+++ b/scripts/2_match_households_and_individuals.py
@@ -8,6 +8,7 @@
 # from joblib import Parallel, delayed
 # from tqdm import trange
 import acbm
+from acbm.assigning.utils import cols_for_assignment_all
 from acbm.cli import acbm_cli
 from acbm.config import load_config
 from acbm.logger_config import matching_logger as logger
@@ -39,14 +40,12 @@ def get_interim_path(

     # ### SPC

-    # useful variables
-    region = "leeds"
-
     logger.info("Loading SPC data")

     # Read in the spc data (parquet format)
     spc = pd.read_parquet(
-        acbm.root_path / "data/external/spc_output/" f"{region}_people_hh.parquet"
+        acbm.root_path / "data/external/spc_output/"
+        f"{config.region}_people_hh.parquet"
     )

     logger.info("Filtering SPC data to specific columns")
@@ -687,194 +686,205 @@ def get_interim_path(
     )  # fill the NaNs with the original values

     # ## Step 3: Matching at Household Level
-
-    logger.info("Categorical matching: MATCHING HOUSEHOLDS")
-
-    #
-    # Now that we've prepared all the columns, we can start matching.
-    # ### 3.1 Categorical matching
-    #
-    # We will match on (a subset of) the following columns:
-    #
-    # | Matching variable | NTS column | SPC column |
-    # | ------------------| ---------- | ---------- |
-    # | Household income | `HHIncome2002_BO2ID` | `salary_yearly_hh_cat` |
-    # | Number of adults | `HHoldNumAdults` | `num_adults` |
-    # | Number of children | `HHoldNumChildren` | `num_children` |
-    # | Employment status | `HHoldEmploy_B01ID` | `pwkstat_NTS_match` |
-    # | Car ownership | `NumCar_SPC_match` | `num_cars` |
-    # | Type of tenancy | `tenure_nts_for_matching` | `tenure_spc_for_matching` |
-    # | Rural/Urban Classification | `Settlement2011EW_B03ID` | `Settlement2011EW_B03ID_spc_CD` |
-
-    # Prepare SPC df for matching
-
-    # Select multiple columns
-    spc_matching = spc_edited[
-        [
-            "hid",
-            "salary_yearly_hh_cat",
-            "num_adults",
-            "num_children",
-            "num_pension_age",
-            "pwkstat_NTS_match",
-            "num_cars",
-            "tenure_spc_for_matching",
-            "Settlement2011EW_B03ID_spc_CD",
-            "Settlement2011EW_B04ID_spc_CD",
+    # TODO: remove once refactored into two scripts
+    load_households = False
+    if not load_households:
+        logger.info("Categorical matching: MATCHING HOUSEHOLDS")
+
+        #
+        # Now that we've prepared all the columns, we can start matching.
+
+        # ### 3.1 Categorical matching
+        #
+        # We will match on (a subset of) the following columns:
+        #
+        # | Matching variable | NTS column | SPC column |
+        # | ------------------| ---------- | ---------- |
+        # | Household income | `HHIncome2002_BO2ID` | `salary_yearly_hh_cat` |
+        # | Number of adults | `HHoldNumAdults` | `num_adults` |
+        # | Number of children | `HHoldNumChildren` | `num_children` |
+        # | Employment status | `HHoldEmploy_B01ID` | `pwkstat_NTS_match` |
+        # | Car ownership | `NumCar_SPC_match` | `num_cars` |
+        # | Type of tenancy | `tenure_nts_for_matching` | `tenure_spc_for_matching` |
+        # | Rural/Urban Classification | `Settlement2011EW_B03ID` | `Settlement2011EW_B03ID_spc_CD` |
+
+        # Prepare SPC df for matching
+
+        # Select multiple columns
+        spc_matching = spc_edited[
+            [
+                "hid",
+                "salary_yearly_hh_cat",
+                "num_adults",
+                "num_children",
+                "num_pension_age",
+                "pwkstat_NTS_match",
+                "num_cars",
+                "tenure_spc_for_matching",
+                "Settlement2011EW_B03ID_spc_CD",
+                "Settlement2011EW_B04ID_spc_CD",
+            ]
         ]
-
-    # edit the df so that we have one row per hid
-    spc_matching = spc_matching.drop_duplicates(subset="hid")
-    spc_matching.head(10)
+        # edit the df so that we have one row per hid
+        spc_matching = spc_matching.drop_duplicates(subset="hid")
+
+        spc_matching.head(10)
+
+        # Prepare NTS df for matching
+
+        nts_matching = nts_households[
+            [
+                "HouseholdID",
+                "HHIncome2002_B02ID",
+                "HHoldNumAdults",
+                "HHoldNumChildren",
+                "num_pension_age_nts",
+                "HHoldEmploy_B01ID",
+                "NumCar_SPC_match",
+                "tenure_nts_for_matching",
+                "Settlement2011EW_B03ID",
+                "Settlement2011EW_B04ID",
+            ]
         ]
-    # Prepare NTS df for matching
+        # Dictionary of matching columns. We extract column names from this dictionary when matching on a subset of the columns
-    nts_matching = nts_households[
-        [
-            "HouseholdID",
-            "HHIncome2002_B02ID",
-            "HHoldNumAdults",
-            "HHoldNumChildren",
-            "num_pension_age_nts",
-            "HHoldEmploy_B01ID",
-            "NumCar_SPC_match",
-            "tenure_nts_for_matching",
-            "Settlement2011EW_B03ID",
-            "Settlement2011EW_B04ID",
+        # column_names (keys) for the dictionary
+        matching_ids = [
+            "household_id",
+            "yearly_income",
+            "number_adults",
+            "number_children",
+            "num_pension_age",
+            "employment_status",
+            "number_cars",
+            "tenure_status",
+            "rural_urban_2_categories",
+            "rural_urban_4_categories",
         ]
-    ]
-    # Dictionary of matching columns. We extract column names from this dictioary when matching on a subset of the columns
-
-    # column_names (keys) for the dictionary
-    matching_ids = [
-        "household_id",
-        "yearly_income",
-        "number_adults",
-        "number_children",
-        "num_pension_age",
-        "employment_status",
-        "number_cars",
-        "tenure_status",
-        "rural_urban_2_categories",
-        "rural_urban_4_categories",
     ]
+        # Dict with value equal to a list with spc_matching and nts_matching column names
+        matching_dfs_dict = {
+            column_name: [spc_value, nts_value]
+            for column_name, spc_value, nts_value in zip(
+                matching_ids, spc_matching, nts_matching
             )
+        }
-    # Dict with value qual to a list with spc_matching and nts_matching column names
-    matching_dfs_dict = {
-        column_name: [spc_value, nts_value]
-        for column_name, spc_value, nts_value in zip(
-            matching_ids, spc_matching, nts_matching
+        # We match iteratively on a subset of columns. We start with all columns, and then remove
+        # one of the optional columns at a time (relaxing the condition). Once a household has over n
+        # matches, we stop matching it to more matches. We continue until all optional columns are removed
+        matcher_exact = MatcherExact(
+            df_pop=spc_matching,
+            df_pop_id="hid",
+            df_sample=nts_matching,
+            df_sample_id="HouseholdID",
+            matching_dict=matching_dfs_dict,
+            fixed_cols=list(config.matching.required_columns),
+            optional_cols=list(config.matching.optional_columns),
+            n_matches=config.matching.n_matches,
+            chunk_size=config.matching.chunk_size,
+            show_progress=True,
         )
-    }
-
-    # We match iteratively on a subset of columns. We start with all columns, and then remove
-    # one of the optionals columns at a time (relaxing the condition). Once a household has over n
-    # matches, we stop matching it to more matches. We continue until all optional columns are removed
-
-    # Define required columns for matching
-    required_columns = [
-        "number_adults",
-        "number_children",
-    ]
-    # Define optional columns in order of importance (most to least important)
-    optional_columns = [
-        "number_cars",
-        "num_pension_age",
-        "rural_urban_2_categories",
-        "employment_status",
-        "tenure_status",
-    ]
-
-    matcher_exact = MatcherExact(
-        df_pop=spc_matching,
-        df_pop_id="hid",
-        df_sample=nts_matching,
-        df_sample_id="HouseholdID",
-        matching_dict=matching_dfs_dict,
-        fixed_cols=required_columns,
-        optional_cols=optional_columns,
-        n_matches=10,
-        chunk_size=50000,
-        show_progress=True,
-    )
-
-    # Match
+        # Match

-    matches_hh_level = matcher_exact.iterative_match_categorical()
+        matches_hh_level = matcher_exact.iterative_match_categorical()

-    # Number of unmatched households
+        # Number of unmatched households

-    # no. of keys where value is na
-    na_count = sum([1 for v in matches_hh_level.values() if pd.isna(v).all()])
-
-    logger.info(f"Categorical matching: {na_count} households in the SPC had no match")
-    logger.info(
-        f"{round((na_count / len(matches_hh_level)) * 100, 1)}% of households in the SPC had no match"
-    )
+        # no. of keys where value is na
+        na_count = sum([1 for v in matches_hh_level.values() if pd.isna(v).all()])

-    ## add matches_hh_level as a column in spc_edited
-    spc_edited["nts_hh_id"] = spc_edited["hid"].map(matches_hh_level)
+        logger.info(
+            f"Categorical matching: {na_count} households in the SPC had no match"
+        )
+        logger.info(
+            f"{round((na_count / len(matches_hh_level)) * 100, 1)}% of households in the SPC had no match"
+        )

-    # ### Random Sampling from matched households
+        # ### Random Sampling from matched households

-    logger.info("Categorical matching: Randomly choosing one match per household")
-    #
-    # In categorical matching, many households in the SPC are matched to more than 1 household in the NTS. Which household to choose? We do random sampling
+        logger.info("Categorical matching: Randomly choosing one match per household")
+        #
+        # In categorical matching, many households in the SPC are matched to more than 1 household in the NTS. Which household to choose? We do random sampling

-    # for each key in the dictionary, sample 1 of the values associated with it and store it in a new dictionary
+        # for each key in the dictionary, sample 1 of the values associated with it and store it in a new dictionary

-    """
-    - iterate over each key-value pair in the matches_hh_result dictionary.
-    - For each key-value pair, use np.random.choice(value) to randomly select
-    one item from the list of values associated with the current key.
-    - create a new dictionary hid_to_HouseholdID_sample where each key from the
-    original dictionary is associated with one randomly selected value from the
-    original list of values.
+        """
+        - iterate over each key-value pair in the matches_hh_result dictionary.
+        - For each key-value pair, use np.random.choice(value) to randomly select
+        one item from the list of values associated with the current key.
+        - create a new dictionary hid_to_HouseholdID_sample where each key from the
+        original dictionary is associated with one randomly selected value from the
+        original list of values.
-    """
-    # Randomly sample one match per household if it has one match or more
-    matches_hh_level_sample = {
-        key: np.random.choice(value)
-        for key, value in matches_hh_level.items()
-        if value
-        and not pd.isna(
-            np.random.choice(value)
-        )  # Ensure the value list is not empty and the selected match is not NaN
-    }
-
-    # Multiple matches in case we want to try stochastic runs
-
-    # Same logic as above, but repeat it multiple times and store each result as a separate dictionary in a list
-    matches_hh_level_sample_list = [
-        {
+        """
+        # Randomly sample one match per household if it has one match or more
+        matches_hh_level_sample = {
             key: np.random.choice(value)
             for key, value in matches_hh_level.items()
-            if value and not pd.isna(np.random.choice(value))
+            if value
+            and not pd.isna(
+                np.random.choice(value)
+            )  # Ensure the value list is not empty and the selected match is not NaN
         }
-        for i in range(25)  # Repeat the process 25 times
-    ]
-    logger.info("Categorical matching: Random sampling complete")
+        # Multiple matches in case we want to try stochastic runs

-    # Save results
-    logger.info("Categorical matching: Saving results")
-    # random sample
-    with open(
-        get_interim_path("matches_hh_level_categorical_random_sample.pkl"), "wb"
-    ) as f:
-        pkl.dump(matches_hh_level_sample, f)
+        # Same logic as above, but repeat it multiple times and store each result as a separate dictionary in a list
+        matches_hh_level_sample_list = [
+            {
+                key: np.random.choice(value)
+                for key, value in matches_hh_level.items()
+                if value and not pd.isna(np.random.choice(value))
+            }
+            for i in range(25)  # Repeat the process 25 times
+        ]

-    # multiple random samples
-    with open(
-        get_interim_path("matches_hh_level_categorical_random_sample_multiple.pkl"),
-        "wb",
-    ) as f:
-        pkl.dump(matches_hh_level_sample_list, f)
+        logger.info("Categorical matching: Random sampling complete")
+
+        # Save results
+        logger.info("Categorical matching: Saving results")
+
+        # matching results
+        with open(get_interim_path("matches_hh_level_categorical.pkl"), "wb") as f:
+            pkl.dump(matches_hh_level, f)
+
+        # random sample
+        with open(
+            get_interim_path("matches_hh_level_categorical_random_sample.pkl"), "wb"
+        ) as f:
+            pkl.dump(matches_hh_level_sample, f)
+
+        # multiple random samples
+        with open(
+            get_interim_path("matches_hh_level_categorical_random_sample_multiple.pkl"),
+            "wb",
+        ) as f:
+            pkl.dump(matches_hh_level_sample_list, f)
+    else:
+        logger.info("Categorical matching: loading matched households")
+        # Load matching result
+        with open(
+            get_interim_path("matches_hh_level_categorical_random_sample.pkl"), "rb"
+        ) as f:
+            matches_hh_level_sample = pkl.load(f)
+
+        # multiple random samples
+        with open(
+            get_interim_path("matches_hh_level_categorical_random_sample_multiple.pkl"),
+            "rb",
+        ) as f:
+            matches_hh_level_sample_list = pkl.load(f)
+
+    # TODO: check if this:
+    # - column is required and possibly update other scripts to add this column in-memory since it is large
+    # - or can use the single sample hh for the new column
+    # For now, updated to use the sample dictionary
+    ## add matches_hh_level as a column in spc_edited
+    spc_edited["nts_hh_id"] = spc_edited["hid"].map(matches_hh_level_sample)

     # Do the same at the df level. Add nts_hh_id_sample column to the spc df
@@ -894,8 +904,6 @@ def get_interim_path(
     #
     #

-    logger.info("Statistical matching: MATCHING INDIVIDUALS")
-
     # Create an 'age' column in the SPC that matches the NTS categories

     # create a dictionary for reference on how the labels for "Age_B04ID" match the actual age brackets
@@ -930,17 +938,33 @@ def get_interim_path(
         columns={"Age_B04ID": "age_group", "Sex_B01ID": "sex"}, inplace=True
     )

-    # PSM matching using internal match_individuals function
+    # TODO: remove once refactored into two scripts
+    load_individuals = False
+    if not load_individuals:
+        logger.info("Statistical matching: MATCHING INDIVIDUALS")
+
+        # PSM matching using internal match_individuals function
+        matches_ind = match_individuals(
+            df1=spc_edited,
+            df2=nts_individuals,
+            matching_columns=["age_group", "sex"],
+            df1_id="hid",
+            df2_id="HouseholdID",
+            matches_hh=matches_hh_level_sample,
+            show_progress=True,
+        )

-    matches_ind = match_individuals(
-        df1=spc_edited,
-        df2=nts_individuals,
-        matching_columns=["age_group", "sex"],
-        df1_id="hid",
-        df2_id="HouseholdID",
-        matches_hh=matches_hh_level_sample,
-        show_progress=True,
-    )
+        # save random sample
+        with open(
+            get_interim_path("matches_ind_level_categorical_random_sample.pkl"), "wb"
+        ) as f:
+            pkl.dump(matches_ind, f)
+    else:
+        logger.info("Statistical matching: loading matched individuals")
+        with open(
+            get_interim_path("matches_ind_level_categorical_random_sample.pkl"), "rb"
+        ) as f:
+            matches_ind = pkl.load(f)

     # Add matches_ind values to spc_edited using map
     spc_edited["nts_ind_id"] = spc_edited.index.map(matches_ind)
@@ -952,12 +976,6 @@ def get_interim_path(

     logger.info("Statistical matching: Matching complete")

-    # save random sample
-    with open(
-        get_interim_path("matches_ind_level_categorical_random_sample.pkl"), "wb"
-    ) as f:
-        pkl.dump(matches_ind, f)
-
     # ### Match on multiple samples
     # logger.info("Statistical matching: Matching on multiple samples")
@@ -1101,9 +1119,20 @@ def get_interim_path(
     # convert the nts_ind_id column to int for merging
     spc_edited_copy["nts_ind_id"] = spc_edited_copy["nts_ind_id"].astype(int)

+    # Add output columns required for assignment scripts
+    spc_output_cols = [
+        col for col in spc_edited_copy.columns if col in cols_for_assignment_all()
+    ]
+    nts_output_cols = [
+        col for col in nts_trips.columns if col in cols_for_assignment_all()
+    ] + ["IndividualID"]
+
     # merge the copy with nts_trips using IndividualID
-    spc_edited_copy = spc_edited_copy.merge(
-        nts_trips, left_on="nts_ind_id", right_on="IndividualID", how="left"
+    spc_edited_copy = spc_edited_copy[spc_output_cols].merge(
+        nts_trips[nts_output_cols],
+        left_on="nts_ind_id",
+        right_on="IndividualID",
+        how="left",
     )

     # save the file as a parquet file
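Note: the `load_households` and `load_individuals` toggles above follow a compute-or-load pattern around the pickled interim results, pending the TODO to split this script in two. A generic, hypothetical sketch of that pattern (the helper and the paths below are illustrative only, not part of the codebase):

```python
import pickle as pkl
from pathlib import Path
from typing import Callable, TypeVar

T = TypeVar("T")


def compute_or_load(path: Path, compute: Callable[[], T], force: bool = False) -> T:
    """Return a cached pickled result if present, otherwise compute and cache it."""
    if path.exists() and not force:
        with open(path, "rb") as f:
            return pkl.load(f)
    result = compute()
    with open(path, "wb") as f:
        pkl.dump(result, f)
    return result


# Hypothetical usage, mirroring the individual-matching step above:
# matches_ind = compute_or_load(interim_dir / "matches_ind.pkl", run_individual_matching)
```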
diff --git a/scripts/3.2.2_assign_primary_zone_work.py b/scripts/3.2.2_assign_primary_zone_work.py
index 82ee9ee..1c1da11 100644
--- a/scripts/3.2.2_assign_primary_zone_work.py
+++ b/scripts/3.2.2_assign_primary_zone_work.py
@@ -62,8 +62,12 @@ def main(config_file):

     # Commuting matrices (from 2021 census)

-    # TODO: consider making this configurable
-    commute_level = config.boundary_geography  # "OA" or "MSOA" data
+    # "OA" or "MSOA" data: set as config.boundary_geography if not passed
+    commute_level = (
+        config.boundary_geography
+        if config.work_assignment.commute_level is None
+        else config.work_assignment.commute_level
+    )

     logger.info(f"Loading commuting matrices at {commute_level} level")
@@ -256,20 +260,20 @@ def main(config_file):
     workzone_assignment_opt["pct_of_o_total_actual"] = workzone_assignment_opt.groupby(
         "origin_zone"
     )["demand_actual"].transform(lambda x: (x / x.sum()) * 100)
-    workzone_assignment_opt[
-        "pct_of_o_total_assigned"
-    ] = workzone_assignment_opt.groupby("origin_zone")["demand_assigned"].transform(
-        lambda x: (x / x.sum()) * 100
+    workzone_assignment_opt["pct_of_o_total_assigned"] = (
+        workzone_assignment_opt.groupby(
+            "origin_zone"
+        )["demand_assigned"].transform(lambda x: (x / x.sum()) * 100)
     )

     # (3) For each OD pair, demand as % of total demand to each destination
     workzone_assignment_opt["pct_of_d_total_actual"] = workzone_assignment_opt.groupby(
         "assigned_zone"
     )["demand_actual"].transform(lambda x: (x / x.sum()) * 100)
-    workzone_assignment_opt[
-        "pct_of_d_total_assigned"
-    ] = workzone_assignment_opt.groupby("assigned_zone")["demand_assigned"].transform(
-        lambda x: (x / x.sum()) * 100
+    workzone_assignment_opt["pct_of_d_total_assigned"] = (
+        workzone_assignment_opt.groupby(
+            "assigned_zone"
+        )["demand_assigned"].transform(lambda x: (x / x.sum()) * 100)
     )

     # Define the output file path
diff --git a/scripts/3.2.3_assign_secondary_zone.py b/scripts/3.2.3_assign_secondary_zone.py
index d649a70..dfc7418 100644
--- a/scripts/3.2.3_assign_secondary_zone.py
+++ b/scripts/3.2.3_assign_secondary_zone.py
@@ -224,6 +224,7 @@ def merge_columns_from_other(df: pd.DataFrame, other: pd.DataFrame) -> pd.DataFr
             "id",
             "household",
             "nts_ind_id",
+            # TODO: check if this column is required
             "nts_hh_id",
             "age_years",
             "oact",
diff --git a/src/acbm/__init__.py b/src/acbm/__init__.py
index 0171ca8..d630574 100644
--- a/src/acbm/__init__.py
+++ b/src/acbm/__init__.py
@@ -1,6 +1,7 @@
 """
 acbm: A package to create activity-based models (for transport demand modelling)
 """
+
 from __future__ import annotations

 import os
diff --git a/src/acbm/assigning/utils.py b/src/acbm/assigning/utils.py
index 24e3561..1fb3ddc 100644
--- a/src/acbm/assigning/utils.py
+++ b/src/acbm/assigning/utils.py
@@ -14,6 +14,7 @@ def cols_for_assignment_all() -> list[str]:
         "household",
         "oact",
         "nts_ind_id",
+        # TODO: check if this column is required
         "nts_hh_id",
         "age_years",
         "TripDisIncSW",
diff --git a/src/acbm/config.py b/src/acbm/config.py
index 1ae03fd..9ea592b 100644
--- a/src/acbm/config.py
+++ b/src/acbm/config.py
@@ -17,12 +17,21 @@ class Parameters(BaseModel):
     boundary_geography: str


+@dataclass(frozen=True)
+class MatchingParams(BaseModel):
+    required_columns: list[str]
+    optional_columns: list[str]
+    n_matches: int | None = None
+    chunk_size: int = 50_000
+
+
 @dataclass(frozen=True)
 class WorkAssignmentParams(BaseModel):
     use_percentages: bool
     weight_max_dev: float
     weight_total_dev: float
     max_zones: int
+    commute_level: str | None


 class Config(BaseModel):
@@ -30,6 +39,7 @@ class Config(BaseModel):
     work_assignment: WorkAssignmentParams = Field(
         description="Config: parameters for work assignment."
     )
+    matching: MatchingParams = Field(description="Config: parameters for matching.")

     @property
     def seed(self) -> int:
@@ -62,7 +72,7 @@ def init_rng(self):
             random.seed(self.seed)
         except Exception as err:
             msg = f"config does not provide a rng seed with err: {err}"
-            ValueError(msg)
+            raise ValueError(msg) from err


 def load_config(filepath: str | Path) -> Config:
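For orientation, a minimal, hedged sketch (not part of the diff; the path and the printed values are assumptions) of how the new config surface defined above is consumed: the `[matching]` block in `config/base.toml` feeds `MatcherExact` in script 2, `chunk_size` falls back to the `MatchingParams` default when it is not set, and when `work_assignment.commute_level` is left as `None` the work-assignment script falls back to `boundary_geography`.

```python
# Hedged sketch only: illustrates the new config fields, not a verified run.
from acbm.config import load_config

config = load_config("config/base.toml")

# [matching] values set in base.toml above
print(config.matching.required_columns)  # ["number_adults", "number_children"]
print(config.matching.n_matches)         # 10; None would mean no per-household cap in MatcherExact
print(config.matching.chunk_size)        # 50_000, the MatchingParams default (not set in base.toml)

# Commuting-matrix level, resolved as in scripts/3.2.2_assign_primary_zone_work.py
commute_level = (
    config.boundary_geography
    if config.work_assignment.commute_level is None
    else config.work_assignment.commute_level
)
```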
diff --git a/src/acbm/matching.py b/src/acbm/matching.py
index 9508fa7..0be0dfd 100644
--- a/src/acbm/matching.py
+++ b/src/acbm/matching.py
@@ -18,7 +18,7 @@ class MatcherExact:
     matching_dict: Dict[str, List[str]]
     fixed_cols: List[str]
     optional_cols: List[str]
-    n_matches: int = 5
+    n_matches: int | None = 10
     chunk_size: int = 50000
     show_progress: bool = True
     matched_dict: Dict[str, List[str]] = field(
@@ -147,11 +147,15 @@ def iterative_match_categorical(self) -> Dict[str, List[str]]:
             self.matched_dict[pop_id].extend(unique_sample_ids)
             self.match_count[pop_id] += len(unique_sample_ids)

-        matched_ids = [
-            pop_id
-            for pop_id, count in self.match_count.items()
-            if count >= self.n_matches
-        ]
+        matched_ids = (
+            [
+                pop_id
+                for pop_id, count in self.match_count.items()
+                if count >= self.n_matches
+            ]
+            if self.n_matches is not None
+            else []
+        )
         self.remaining_df_pop = self.remaining_df_pop[
             ~self.remaining_df_pop[self.df_pop_id].isin(matched_ids)
         ]
@@ -264,13 +268,19 @@ def match_individuals(
     # Remove all unmateched households
     matches_hh = {key: value for key, value in matches_hh.items() if not pd.isna(value)}

-    # loop over all rows in the matches_hh dictionary
-    for i, (key, value) in enumerate(matches_hh.items(), 1):
-        # Get the rows in df1 and df2 that correspond to the matched hids
-        rows_df1 = df1[df1[df1_id] == key]
+    # loop over all groups of df1_id
+    # note: for large populations looping through the groups (keys) of the
+    # large dataframe (assumed to be df1) is more efficient than looping
+    # over keys and subsetting on a key in each iteration.
+    for i, (key, rows_df1) in enumerate(df1.groupby(df1_id), 1):
+        try:
+            value = matches_hh[key]
+        except Exception:
+            # Continue if key not in matches_hh
+            continue
         rows_df2 = df2[df2[df2_id] == int(value)]

-        if show_progress:
+        if show_progress and i % 100 == 0:
             # Print the iteration number and the number of keys in the dict
             print(f"Matching for household {i} out of: {len(matches_hh)}")
diff --git a/tests/test_matching.py b/tests/test_matching.py
index fcd17f1..6aa7727 100644
--- a/tests/test_matching.py
+++ b/tests/test_matching.py
@@ -4,7 +4,7 @@
 from acbm.matching import MatcherExact, match_psm  # noqa: F401


-@pytest.fixture()
+@pytest.fixture
 def setup_data():
     df_pop = pd.DataFrame(
         {