Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RTR] Fix sol.integrate shooting.single continuity #859

Merged
merged 18 commits into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 45 additions & 5 deletions bioptim/optimization/solution/solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,17 @@
"an InitialGuess[List] of len 4 (states, controls, parameters, algebraic_states), "
"or a None"
)
if sum([len(s) != len(all_ns) if p != 4 else False for p, s in enumerate(sol)]) != 0:

if len(sol[0]) != len(all_ns):
raise ValueError("The time step dt array len must match the number of phases")

Check warning on line 235 in bioptim/optimization/solution/solution.py

View check run for this annotation

Codecov / codecov/patch

bioptim/optimization/solution/solution.py#L235

Added line #L235 was not covered by tests

is_right_size = [
len(s) != len(all_ns) if p != 3 and len(sol[p + 1].keys()) != 0 else False for p, s in enumerate(sol[:1])
]

if sum(is_right_size) != 0:
raise ValueError("The InitialGuessList len must match the number of phases")

if n_param != 0:
if len(sol) != 3 and len(sol[3]) != 1 and sol[3][0].shape != (n_param, 1):
raise ValueError(
Expand Down Expand Up @@ -281,7 +290,8 @@
# For parameters
if n_param:
for p, ss in enumerate(sol_params):
vector = np.concatenate((vector, np.repeat(ss.init, all_ns[p] + 1)[:, np.newaxis]))
for key in ss.keys():
vector = np.concatenate((vector, np.repeat(ss[key].init, 1)[:, np.newaxis]))

Check warning on line 294 in bioptim/optimization/solution/solution.py

View check run for this annotation

Codecov / codecov/patch

bioptim/optimization/solution/solution.py#L293-L294

Added lines #L293 - L294 were not covered by tests

# For algebraic_states variables
for p, ss in enumerate(sol_algebraic_states):
Expand Down Expand Up @@ -378,6 +388,7 @@
to_merge: SolutionMerge | list[SolutionMerge] = None,
time_alignment: TimeAlignment = TimeAlignment.STATES,
continuous: bool = True,
duplicated_times: bool = True,
) -> list | np.ndarray:
"""
Returns the time vector at each node that matches stepwise_states or stepwise_controls
Expand All @@ -394,6 +405,9 @@
continuous: bool
If the time should be continuous throughout the whole ocp. If False, then the time is reset at the
beginning of each phase.
duplicated_times: bool
If the times should be duplicated for each nodes.
If False, then the returned time vector will not have any duplicated times

Returns
-------
Expand All @@ -405,6 +419,7 @@
to_merge=to_merge,
time_alignment=time_alignment,
continuous=continuous,
duplicated_times=duplicated_times,
)

def _process_time_vector(
Expand All @@ -413,6 +428,7 @@
to_merge: SolutionMerge | list[SolutionMerge],
time_alignment: TimeAlignment,
continuous: bool,
duplicated_times: bool = True,
):
if to_merge is None or isinstance(to_merge, SolutionMerge):
to_merge = [to_merge]
Expand Down Expand Up @@ -467,6 +483,14 @@
else:
raise ValueError("time_alignment should be either TimeAlignment.STATES or TimeAlignment.CONTROLS")

if not duplicated_times:
for i in range(len(times)):
for j in range(len(times[i])):
keep_condition = times[i][j].shape[0] == 1 and i == len(times) - 1
times[i][j] = times[i][j][:] if keep_condition else times[i][j][:-1]
if j == len(times[i]) - 1 and i != len(times) - 1:
del times[i][j]

if continuous:
for phase_idx, phase_time in enumerate(times):
if phase_idx == 0:
Expand Down Expand Up @@ -686,6 +710,8 @@
shooting_type: Shooting = Shooting.SINGLE,
integrator: SolutionIntegrator = SolutionIntegrator.OCP,
to_merge: SolutionMerge | list[SolutionMerge] = None,
duplicated_times: bool = True,
return_time: bool = False,
):
has_direct_collocation = sum([nlp.ode_solver.is_direct_collocation for nlp in self.ocp.nlp]) > 0
if has_direct_collocation and integrator == SolutionIntegrator.OCP:
Expand Down Expand Up @@ -717,6 +743,7 @@
integrated_sol = None
for p, nlp in enumerate(self.ocp.nlp):
next_x = self._states_for_phase_integration(shooting_type, p, integrated_sol, x, u, params, a)

integrated_sol = solve_ivp_interface(
shooting_type=shooting_type,
nlp=nlp,
Expand All @@ -732,12 +759,24 @@
for key in nlp.states.keys():
out[p][key] = [None] * nlp.n_states_nodes
for ns, sol_ns in enumerate(integrated_sol):
out[p][key][ns] = sol_ns[nlp.states[key].index, :]
if duplicated_times:
out[p][key][ns] = sol_ns[nlp.states[key].index, :]
else:
duplicated_times_condition = p == len(self.ocp.nlp) - 1 and ns == nlp.ns
out[p][key][ns] = (
sol_ns[nlp.states[key].index, :]
if duplicated_times_condition
else sol_ns[nlp.states[key].index, :-1]
)

if to_merge:
out = SolutionData.from_unscaled(self.ocp, out, "x").to_dict(to_merge=to_merge, scaled=False)

return out if len(out) > 1 else out[0]
if return_time:
time_vector = self.stepwise_time(to_merge=to_merge, duplicated_times=duplicated_times)
return out if len(out) > 1 else out[0], time_vector if len(time_vector) > 1 else time_vector[0]
else:
return out if len(out) > 1 else out[0]

def _states_for_phase_integration(
self,
Expand Down Expand Up @@ -791,6 +830,7 @@
# based on the phase transition objective or constraint function. That is why we need to concatenate
# twice the last state
x = PenaltyHelpers.states(penalty, 0, lambda p, n, sn: integrated_states[-1])

u = PenaltyHelpers.controls(
penalty,
0,
Expand All @@ -812,7 +852,7 @@
f"please integrate with Shooting.SINGLE_DISCONTINUOUS_PHASE."
)

return [decision_states[phase_idx][0] + dx]
return [(integrated_states[-1] if shooting_type == Shooting.SINGLE else decision_states[phase_idx][0]) + dx]

def _integrate_stepwise(self) -> None:
"""
Expand Down
12 changes: 12 additions & 0 deletions tests/shard4/test_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,18 @@ def test_integrate_multiphase(shooting, integrator, ode_solver, phase_dynamics,
assert sol_integrated[i][key].shape == (shapes[k], n_shooting[i] * n_steps + 1)
assert states[i][key].shape == (shapes[k], n_shooting[i] * n_steps + 1)

opts = {
"shooting_type": shooting,
"integrator": integrator,
"to_merge": [SolutionMerge.PHASES, SolutionMerge.NODES],
}
sol_integrated, sol_time = sol.integrate(**opts, duplicated_times=False, return_time=True)
sol_time = np.concatenate(sol_time, axis=0)
assert len(sol_time) == len(np.unique(sol_time))
for i in range(len(sol_integrated)):
for k, key in enumerate(states[i]):
assert len(sol_integrated[key][k]) == len(sol_time)


def test_check_models_comes_from_same_super_class():
from bioptim.examples.getting_started import example_multiphase as ocp_module
Expand Down
4 changes: 2 additions & 2 deletions tests/shard6/test_global_stochastic_except_collocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_arm_reaching_muscle_driven(use_sx):
match = None
else:
match = re.escape(
"Error in Function::call for 'tp' [MXFunction] at .../casadi/core/function.cpp:339:\n"
"Error in Function::call for 'tp' [MXFunction] at .../casadi/core/function.cpp:370:\n"
".../casadi/core/linsol_internal.cpp:65: eval_sx not defined for LinsolQr"
)
with pytest.raises(RuntimeError, match=match):
Expand Down Expand Up @@ -282,7 +282,7 @@ def test_arm_reaching_torque_driven_explicit(use_sx):
match = None
else:
match = re.escape(
"Error in Function::call for 'tp' [MXFunction] at .../casadi/core/function.cpp:339:\n"
"Error in Function::call for 'tp' [MXFunction] at .../casadi/core/function.cpp:370:\n"
".../casadi/core/linsol_internal.cpp:65: eval_sx not defined for LinsolQr"
)
with pytest.raises(RuntimeError, match=match):
Expand Down
Loading