Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor to improve (but not entirely fix) type hints #58

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion policyengine/outputs/household/single/net_income.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
def net_income(simulation):
return simulation.selected.calculate("household_net_income").sum()
return simulation.selected_sim.calculate("household_net_income").sum()
2 changes: 1 addition & 1 deletion policyengine/outputs/macro/single/gov/balance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


def balance(simulation: Simulation) -> dict:
sim = simulation.selected
sim = simulation.selected_sim
if simulation.country == "uk":
total_tax = sim.calculate("gov_tax").sum()
total_spending = sim.calculate("gov_spending").sum()
Expand Down
2 changes: 1 addition & 1 deletion policyengine/outputs/macro/single/gov/budget_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


def budget_window(simulation: Simulation, count_years: int = None) -> dict:
sim = simulation.selected
sim = simulation.selected_sim
current_year = simulation.time_period
if count_years is not None:
years = list(range(current_year, current_year + count_years))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def parliamentary_constituencies(

result = {}

sim = simulation.selected
sim = simulation.selected_sim
original_hh_weight = sim.calculate("household_weight").values

for constituency_id in range(weights.shape[0]):
Expand All @@ -63,7 +63,7 @@ def parliamentary_constituencies(
sim.get_holder("benunit_weight").delete_arrays(
sim.default_calculation_period
)
calculation_result = metric(simulation.selected)
calculation_result = metric(simulation.selected_sim)
code = constituency_names.code.iloc[constituency_id]
result[constituency_names.set_index("code").loc[code]["name"]] = (
calculation_result
Expand Down
2 changes: 1 addition & 1 deletion policyengine/outputs/macro/single/gov/programs.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class UKPrograms:
def programs(simulation: Simulation) -> dict:
if simulation.country == "uk":
return {
program.name: simulation.selected.calculate(
program.name: simulation.selected_sim.calculate(
program.name, map_to="household"
).sum()
* (1 if program.is_positive else -1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


def demographics(simulation: Simulation) -> dict:
sim = simulation.selected
sim = simulation.selected_sim
household_count_people = (
sim.calculate("household_count_people").astype(int).tolist()
)
Expand Down
2 changes: 1 addition & 1 deletion policyengine/outputs/macro/single/household/finance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


def finance(simulation: Simulation) -> dict:
sim = simulation.selected
sim = simulation.selected_sim

total_net_income = sim.calculate("household_net_income").sum()
employment_income_hh = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def income_distribution(simulation: Simulation, chart: bool = False) -> dict:


def income_distribution_chart(simulation: Simulation) -> go.Figure:
income = simulation.baseline.calculate("household_net_income")
income = simulation.baseline_sim.calculate("household_net_income")
income_upper = income.quantile(0.9)
BAND_SIZE = 5_000
lower_income_bands = []
Expand Down
4 changes: 2 additions & 2 deletions policyengine/outputs/macro/single/household/inequality.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@


def inequality(simulation: Simulation) -> dict:
personal_hh_equiv_income = simulation.selected.calculate(
personal_hh_equiv_income = simulation.selected_sim.calculate(
"equiv_household_net_income"
)
personal_hh_equiv_income[personal_hh_equiv_income < 0] = 0
household_count_people = simulation.selected.calculate(
household_count_people = simulation.selected_sim.calculate(
"household_count_people"
).values
personal_hh_equiv_income.weights *= household_count_people
Expand Down
4 changes: 2 additions & 2 deletions policyengine/outputs/macro/single/household/labor_supply.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


def labor_supply(simulation: Simulation) -> dict:
sim = simulation.selected
sim = simulation.selected_sim
household_count_people = sim.calculate("household_count_people").values
result = {
"substitution_lsr": 0,
Expand Down Expand Up @@ -54,7 +54,7 @@ def labor_supply(simulation: Simulation) -> dict:


def has_behavioral_response(simulation):
sim = simulation.selected
sim = simulation.selected_sim
return (
"employment_income_behavioral_response"
in sim.tax_benefit_system.variables
Expand Down
101 changes: 50 additions & 51 deletions policyengine/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,46 +14,45 @@
from typing import Tuple, Any
from policyengine.constants import *


class Simulation:
"""The top-level class through which all PE usage is carried out."""

country: str
"""The country for which the simulation is being run."""
scope: str
"""The type of simulation being run (macro or household)."""
data: str
data: dict | str | Dataset
"""The dataset being used for the simulation."""
time_period: str
time_period: str | None
"""The time period for the simulation. Years are applicable."""
baseline: dict
baseline: dict | None
nikhilwoodruff marked this conversation as resolved.
Show resolved Hide resolved
"""The baseline simulation inputs."""
reform: dict
reform: dict | None
"""The reform simulation inputs."""
options: dict
"""Dynamic options for the simulation type."""

comparison: bool
"""Whether we are comparing two simulations, or analysing a single one."""
baseline: CountrySimulation = None
baseline_sim: CountrySimulation
"""The tax-benefit simulation for the baseline scenario."""
reformed: CountrySimulation = None
"""The tax-benefit simulation for the reformed scenario."""
selected: CountryMicrosimulation = None
"""The selected simulation for the current calculation."""
reformed_sim: CountrySimulation | None = None
"""The tax-benefit simulation for the reformed scenario. None if no reform has been configured."""
selected_sim: CountrySimulation | None = None
"""The simulation selected for the current calculation (either the baseline or the reformed one). None until an output function has selected one."""
verbose: bool = False
"""Whether to print out progress messages."""

def __init__(
self,
country: str,
scope: str,
data: str = None,
time_period: str = None,
reform: dict = None,
baseline: dict = None,
data: str | dict | None = None,
time_period: str | None = None,
reform: dict | None = None,
baseline: dict | None = None,
verbose: bool = False,
options: dict = None,
options: dict | None = None,
):
"""Initialise the simulation with the given parameters.

Expand All @@ -70,17 +69,6 @@ def __init__(
self.time_period = time_period
self.verbose = verbose
self.options = options or {}

if isinstance(reform, dict):
reform = Reform.from_dict(reform, country_id=country)
elif isinstance(reform, int):
reform = Reform.from_api(reform, country_id=country)

if isinstance(baseline, dict):
baseline = Reform.from_dict(baseline, country_id=country)
elif isinstance(baseline, int):
baseline = Reform.from_api(baseline, country_id=country)

self.baseline = baseline
self.reform = reform

Expand All @@ -90,17 +78,14 @@ def __init__(

self._initialise_simulations()

def _set_dataset(self, dataset: str):
def _set_dataset(self, dataset: str | dict | None):
if isinstance(dataset, dict):
self.data = dataset
return

self.data = DEFAULT_DATASETS[self.country]
if dataset in DATASETS[self.country]:
self.data = DATASETS[self.country][dataset]
elif dataset is None:
self.data = DEFAULT_DATASETS[self.country]
else:
self.data = dataset

# Short-term hacky fix: handle legacy 'array' datasets that don't specify the year for each variable.
# We should transition these to variable/period/value format, but they're used frequently for now, so we need backwards compatibility.
Expand All @@ -119,7 +104,7 @@ def _set_dataset(self, dataset: str):
local_folder=None,
version=version,
)
self.data = Dataset.from_file(self.data, 2023)
self.data = Dataset.from_file(self.data, "2023")

def calculate(self, output: str, force: bool = False, **kwargs) -> Any:
"""Calculate the given output (path).
Expand Down Expand Up @@ -178,6 +163,8 @@ def _get_outputs(self) -> Tuple[dict, dict]:
for output in Path(__file__).parent.glob("outputs/**/*.py"):
module_name = output.stem
spec = importlib.util.spec_from_file_location(module_name, output)
if spec is None:
raise RuntimeError(f"Expected to load a spec from file '{output.absolute}'")
module = importlib.util.module_from_spec(spec)
relative_path = str(
output.relative_to(Path(__file__).parent / "outputs")
Expand All @@ -189,6 +176,8 @@ def _get_outputs(self) -> Tuple[dict, dict]:
# Don't load household modules for macro comparisons, etc.
continue

if spec.loader is None:
raise RuntimeError(f"Expected module from '{output.absolute}' to have a loader, but it does not")
nikhilwoodruff marked this conversation as resolved.
Show resolved Hide resolved
spec.loader.exec_module(module)

# Only import the function with the same name as the module, enforcing one function per file
Expand All @@ -204,11 +193,11 @@ def _get_outputs(self) -> Tuple[dict, dict]:
func = output_functions[key]

def passed_reform_simulation(func, is_reform):
def adjusted_func(simulation, **kwargs):
def adjusted_func(simulation:Simulation, **kwargs):
if is_reform:
simulation.selected = simulation.reformed
simulation.selected_sim = simulation.reformed_sim
else:
simulation.selected = simulation.baseline
simulation.selected_sim = simulation.baseline_sim
return func(simulation, **kwargs)

return adjusted_func
Expand Down Expand Up @@ -242,7 +231,15 @@ def adjusted_func(simulation, **kwargs):

return output_functions, outputs

def _to_reform(self, value: int | dict):
if isinstance(value, dict):
return Reform.from_dict(value, country_id = self.country)
return Reform.from_api(f"{value}", country_id = self.country)

def _initialise_simulations(self):
self._parsed_reform = self._to_reform(self.reform) if self.reform is not None else None
self._parsed_baseline = self._to_reform(self.baseline) if self.baseline is not None else None

macro = self.scope == "macro"
_simulation_type = {
"uk": {
Expand All @@ -254,37 +251,39 @@ def _initialise_simulations(self):
False: USSimulation,
},
}[self.country][macro]
self.baseline = _simulation_type(
self.baseline_sim = _simulation_type(
dataset=self.data if macro else None,
situation=self.data if not macro else None,
reform=self.baseline,
reform=self._parsed_baseline,
)
self.baseline.default_calculation_period = self.time_period

if self.time_period is not None:
self.baseline_sim.default_calculation_period = self.time_period

if "subsample" in self.options:
self.baseline = self.baseline.subsample(self.options["subsample"])
self.baseline_sim = self.baseline_sim.subsample(self.options["subsample"])

if "region" in self.options:
self.baseline = self._apply_region_to_simulation(
self.baseline, _simulation_type, self.options["region"]
if "region" in self.options and isinstance(self.baseline_sim, CountryMicrosimulation):
self.baseline_sim = self._apply_region_to_simulation(
self.baseline_sim, _simulation_type, self.options["region"]
)

if self.comparison:
self.reformed = _simulation_type(
self.reformed_sim = _simulation_type(
dataset=self.data if macro else None,
situation=self.data if not macro else None,
reform=self.reform,
reform=self._parsed_reform,
)
self.reformed.default_calculation_period = self.time_period
self.reformed_sim.default_calculation_period = self.time_period

if "subsample" in self.options:
self.reformed = self.reformed.subsample(
self.reformed_sim = self.reformed_sim.subsample(
self.options["subsample"]
)

if "region" in self.options:
self.reformed = self._apply_region_to_simulation(
self.reformed, _simulation_type, self.options["region"]
if "region" in self.options and isinstance(self.reformed_sim, CountryMicrosimulation):
self.reformed_sim = self._apply_region_to_simulation(
self.reformed_sim, _simulation_type, self.options["region"]
)

def _apply_region_to_simulation(
Expand All @@ -301,12 +300,12 @@ def _apply_region_to_simulation(
if region == "city/nyc":
in_nyc = simulation.calculate("in_nyc", map_to="person").values
simulation = simulation_type(
dataset=df[in_nyc], reform=self.reform
dataset=df[in_nyc], reform=self._parsed_reform
)
elif "state/" in region:
state = region.split("/")[1]
simulation = simulation_type(
dataset=df[state_code == state.upper()], reform=self.reform
dataset=df[state_code == state.upper()], reform=self._parsed_reform
)

return simulation
2 changes: 1 addition & 1 deletion policyengine/utils/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


def download(
repo: str, repo_filename: str, local_folder: str, version: str = None
repo: str, repo_filename: str, local_folder: str | None = None, version: str | None = None
):
token = os.environ.get("HUGGING_FACE_TOKEN")
if token is None:
Expand Down
Loading