Skip to content

Commit

Permalink
Adapt Tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mj023 committed Nov 6, 2024
1 parent ff093d9 commit 8f419d3
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/lcm/entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def get_lcm_function(
targets: Literal["solve", "simulate", "solve_and_simulate"] = "solve",
*,
debug_mode: bool = True,
jit: bool = False,
jit: bool = True,
) -> tuple[Callable, ParamsDict]:
"""Entry point for users to get high level functions generated by lcm.
Expand Down
2 changes: 1 addition & 1 deletion tests/test_regression_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,4 @@ def test_regression_test():
# Compare
# ==================================================================================
aaae(expected_solve, got_solve, decimal=5)
assert_frame_equal(expected_simulate, got_simulate)
assert_frame_equal(expected_simulate, got_simulate, check_like=True)
5 changes: 3 additions & 2 deletions tests/test_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ def test_simulate_using_raw_inputs(simulate_inputs):
**simulate_inputs,
)

assert_array_equal(got.loc[0, :]["retirement"], 1)
assert_array_almost_equal(got.loc[0, :]["consumption"], jnp.array([1.0, 50.400803]))
assert_array_equal(got["retirement"], 1)
assert_array_almost_equal(got["consumption"], jnp.array([1.0, 50.400803]))


# ======================================================================================
Expand Down Expand Up @@ -336,6 +336,7 @@ def f_b(b, params): # noqa: ARG001
params={"disutility_of_work": -1.0},
)
expected = {
**processed_results,
"fa": jnp.arange(3) - 1.0,
"fb": 1 + jnp.arange(3),
}
Expand Down

0 comments on commit 8f419d3

Please sign in to comment.