Skip to content

Commit

Permalink
Integrate comments
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Sep 21, 2024
1 parent b1264ec commit b4e7ac6
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 38 deletions.
2 changes: 1 addition & 1 deletion src/lcm/entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def get_lcm_function(
elif targets == "solve_and_simulate":
_target = partial(simulate_model, solve_model=solve_model)

user_params = _mod.discrete_grid_converter.internal_to_params(_mod.params)
user_params = _mod.discrete_grid_converter.internal_params_to_params(_mod.params)
return cast(Callable, _target), user_params


Expand Down
46 changes: 21 additions & 25 deletions src/lcm/input_processing/discrete_grid_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,66 +44,62 @@ def __post_init__(self) -> None:
"The keys of index_to_code and code_to_index must be the same."
)

def internal_to_params(self, params: ParamsDict) -> ParamsDict:
def internal_params_to_params(self, internal: ParamsDict) -> ParamsDict:
"""Convert parameters from internal to external representation.
If a state has been converted, the name of its corresponding next function must
be changed from `next___{var}_index__` to `next_{var}`.
"""
out = params.copy()
params = internal.copy()
for var in self.index_to_code:
old_name = f"next___{var}_index__"
if old_name in out:
out[f"next_{var}"] = out.pop(old_name)
return out
if old_name in params:
params[f"next_{var}"] = params.pop(old_name)
return params

def params_to_internal(self, params: ParamsDict) -> ParamsDict:
def params_to_internal_params(self, params: ParamsDict) -> ParamsDict:
"""Convert parameters from external to internal representation.
If a state has been converted, the name of its corresponding next function must
be changed from `next_{var}` to `next___{var}_index__`.
"""
out = params.copy()
internal = params.copy()
for var in self.index_to_code:
old_name = f"next_{var}"
if old_name in out:
out[f"next___{var}_index__"] = out.pop(old_name)
return out
if old_name in internal:
internal[f"next___{var}_index__"] = internal.pop(old_name)
return internal

def internal_to_discrete_vars(
self, variables: dict[str, Array]
) -> dict[str, Array]:
"""Convert states from internal to external representation.
def internal_vars_to_vars(self, internal: dict[str, Array]) -> dict[str, Array]:
"""Convert discrete variables from internal to external representation.
If a variable has been converted, the name of its corresponding index function
must be changed from `___{var}_index__` to `{var}`, and the values of the
variable must be converted from indices to codes.
"""
out = variables.copy()
variables = internal.copy()
for var, index_to_code in self.index_to_code.items():
old_name = f"__{var}_index__"
if old_name in variables:
out[var] = index_to_code(out.pop(old_name))
return out
if old_name in internal:
variables[var] = index_to_code(variables.pop(old_name))
return variables

def discrete_vars_to_internal(
self, variables: dict[str, Array]
) -> dict[str, Array]:
"""Convert states from external to internal representation.
def vars_to_internal_vars(self, variables: dict[str, Array]) -> dict[str, Array]:
"""Convert discrete variables from external to internal representation.
If a variable has been converted, the name of its corresponding index function
must be changed from `{var}` to `___{var}_index__`, and the values of the
variable must be converted from codes to indices.
"""
out = variables.copy()
internal = variables.copy()
for var, code_to_index in self.code_to_index.items():
if var in variables:
out[f"__{var}_index__"] = code_to_index(out.pop(var))
return out
internal[f"__{var}_index__"] = code_to_index(internal.pop(var))
return internal


def convert_arbitrary_codes_to_array_indices(
Expand Down
6 changes: 3 additions & 3 deletions src/lcm/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def simulate(
# will do it.
vf_arr_list = solve_model(params)

internal_params = discrete_grid_converter.params_to_internal(params)
internal_params = discrete_grid_converter.params_to_internal_params(params)

logger.info("Starting simulation")

Expand All @@ -97,7 +97,7 @@ def simulate(
sparse_choice_variables = model.variable_info.query("is_choice & is_sparse").index

# The following variables are updated during the forward simulation
states = discrete_grid_converter.discrete_vars_to_internal(initial_states)
states = discrete_grid_converter.vars_to_internal_vars(initial_states)
key = jax.random.PRNGKey(seed=seed)

# Forward simulation
Expand Down Expand Up @@ -372,7 +372,7 @@ def _process_simulated_data(results, discrete_grid_converter):
out = {key: jnp.concatenate(values) for key, values in dict_of_lists.items()}
out["_period"] = jnp.repeat(jnp.arange(n_periods), n_initial_states)

return discrete_grid_converter.internal_to_discrete_vars(out)
return discrete_grid_converter.internal_vars_to_vars(out)


# ======================================================================================
Expand Down
2 changes: 1 addition & 1 deletion src/lcm/solve_brute.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def solve(
list: List with one value function array per period.
"""
internal_params = discrete_grid_converter.params_to_internal(params)
internal_params = discrete_grid_converter.params_to_internal_params(params)

# extract information
n_periods = len(state_choice_spaces)
Expand Down
12 changes: 4 additions & 8 deletions tests/input_processing/test_discrete_state_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def test_discrete_state_converter_internal_to_params(discrete_state_converter_kw
"next___c_index__": 1,
}
converter = DiscreteGridConverter(**discrete_state_converter_kwargs)
assert converter.internal_to_params(internal_params) == expected
assert converter.internal_params_to_params(internal_params) == expected


def test_discrete_state_converter_params_to_internal(discrete_state_converter_kwargs):
Expand All @@ -78,7 +78,7 @@ def test_discrete_state_converter_params_to_internal(discrete_state_converter_kw
"next_c": 1,
}
converter = DiscreteGridConverter(**discrete_state_converter_kwargs)
assert converter.params_to_internal(params) == expected
assert converter.params_to_internal_params(params) == expected


def test_discrete_state_converter_internal_to_discrete_vars(
Expand All @@ -89,9 +89,7 @@ def test_discrete_state_converter_internal_to_discrete_vars(
"__c_index__": jnp.array([0, 1]),
}
converter = DiscreteGridConverter(**discrete_state_converter_kwargs)
assert_array_equal(
converter.internal_to_discrete_vars(internal_states)["c"], expected
)
assert_array_equal(converter.internal_vars_to_vars(internal_states)["c"], expected)


def test_discrete_state_converter_discrete_vars_to_internal(
Expand All @@ -102,9 +100,7 @@ def test_discrete_state_converter_discrete_vars_to_internal(
"c": jnp.array([1, 0]),
}
converter = DiscreteGridConverter(**discrete_state_converter_kwargs)
assert_array_equal(
converter.discrete_vars_to_internal(states)["__c_index__"], expected
)
assert_array_equal(converter.vars_to_internal_vars(states)["__c_index__"], expected)


def test_discrete_state_converter_raises_error_if_keys_dont_match():
Expand Down

0 comments on commit b4e7ac6

Please sign in to comment.