Skip to content

Commit

Permalink
Merge pull request #859 from Kev1CO/integrate_shooting_single_correction
Browse files Browse the repository at this point in the history
Fix sol.integrate shooting.single continuity
  • Loading branch information
pariterre authored Mar 15, 2024
2 parents a377e94 + 4cda59c commit 45ff432
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 23 deletions.
102 changes: 97 additions & 5 deletions bioptim/optimization/solution/solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,17 @@ def from_initial_guess(cls, ocp, sol: list):
"an InitialGuess[List] of len 4 (states, controls, parameters, algebraic_states), "
"or a None"
)
if sum([len(s) != len(all_ns) if p != 4 else False for p, s in enumerate(sol)]) != 0:

if len(sol[0]) != len(all_ns):
raise ValueError("The time step dt array len must match the number of phases")

is_right_size = [
len(s) != len(all_ns) if p != 3 and len(sol[p + 1].keys()) != 0 else False for p, s in enumerate(sol[:1])
]

if sum(is_right_size) != 0:
raise ValueError("The InitialGuessList len must match the number of phases")

if n_param != 0:
if len(sol) != 3 and len(sol[3]) != 1 and sol[3][0].shape != (n_param, 1):
raise ValueError(
Expand Down Expand Up @@ -281,7 +290,8 @@ def from_initial_guess(cls, ocp, sol: list):
# For parameters
if n_param:
for p, ss in enumerate(sol_params):
vector = np.concatenate((vector, np.repeat(ss.init, all_ns[p] + 1)[:, np.newaxis]))
for key in ss.keys():
vector = np.concatenate((vector, np.repeat(ss[key].init, 1)[:, np.newaxis]))

# For algebraic_states variables
for p, ss in enumerate(sol_algebraic_states):
Expand Down Expand Up @@ -378,6 +388,7 @@ def stepwise_time(
to_merge: SolutionMerge | list[SolutionMerge] = None,
time_alignment: TimeAlignment = TimeAlignment.STATES,
continuous: bool = True,
duplicated_times: bool = True,
) -> list | np.ndarray:
"""
Returns the time vector at each node that matches stepwise_states or stepwise_controls
Expand All @@ -394,6 +405,9 @@ def stepwise_time(
continuous: bool
If the time should be continuous throughout the whole ocp. If False, then the time is reset at the
beginning of each phase.
duplicated_times: bool
If the times should be duplicated for each nodes.
If False, then the returned time vector will not have any duplicated times
Returns
-------
Expand All @@ -405,6 +419,7 @@ def stepwise_time(
to_merge=to_merge,
time_alignment=time_alignment,
continuous=continuous,
duplicated_times=duplicated_times,
)

def _process_time_vector(
Expand All @@ -413,6 +428,7 @@ def _process_time_vector(
to_merge: SolutionMerge | list[SolutionMerge],
time_alignment: TimeAlignment,
continuous: bool,
duplicated_times: bool = True,
):
if to_merge is None or isinstance(to_merge, SolutionMerge):
to_merge = [to_merge]
Expand Down Expand Up @@ -467,6 +483,15 @@ def _process_time_vector(
else:
raise ValueError("time_alignment should be either TimeAlignment.STATES or TimeAlignment.CONTROLS")

if not duplicated_times:
for i in range(len(times)):
for j in range(len(times[i])):
# Last node of last phase is always kept
keep_condition = times[i][j].shape[0] == 1 and i == len(times) - 1
times[i][j] = times[i][j][:] if keep_condition else times[i][j][:-1]
if j == len(times[i]) - 1 and i != len(times) - 1:
del times[i][j]

if continuous:
for phase_idx, phase_time in enumerate(times):
if phase_idx == 0:
Expand Down Expand Up @@ -686,7 +711,31 @@ def integrate(
shooting_type: Shooting = Shooting.SINGLE,
integrator: SolutionIntegrator = SolutionIntegrator.OCP,
to_merge: SolutionMerge | list[SolutionMerge] = None,
duplicated_times: bool = True,
return_time: bool = False,
):
"""
Create a deepcopy of the Solution
Parameters
----------
shooting_type: Shooting
The integration shooting type to use
integrator: SolutionIntegrator
The type of integrator to use
to_merge: SolutionMerge | list[SolutionMerge, ...]
The type of merge to perform. If None, then no merge is performed.
duplicated_times: bool
If the times should be duplicated for each node.
If False, then the returned time vector will not have any duplicated times.
return_time: bool
If the time vector should be returned
Returns
-------
Return the integrated states
"""

has_direct_collocation = sum([nlp.ode_solver.is_direct_collocation for nlp in self.ocp.nlp]) > 0
if has_direct_collocation and integrator == SolutionIntegrator.OCP:
raise ValueError(
Expand Down Expand Up @@ -717,6 +766,7 @@ def integrate(
integrated_sol = None
for p, nlp in enumerate(self.ocp.nlp):
next_x = self._states_for_phase_integration(shooting_type, p, integrated_sol, x, u, params, a)

integrated_sol = solve_ivp_interface(
shooting_type=shooting_type,
nlp=nlp,
Expand All @@ -732,12 +782,25 @@ def integrate(
for key in nlp.states.keys():
out[p][key] = [None] * nlp.n_states_nodes
for ns, sol_ns in enumerate(integrated_sol):
out[p][key][ns] = sol_ns[nlp.states[key].index, :]
if duplicated_times:
out[p][key][ns] = sol_ns[nlp.states[key].index, :]
else:
# Last node of last phase is always kept
duplicated_times_condition = p == len(self.ocp.nlp) - 1 and ns == nlp.ns
out[p][key][ns] = (
sol_ns[nlp.states[key].index, :]
if duplicated_times_condition
else sol_ns[nlp.states[key].index, :-1]
)

if to_merge:
out = SolutionData.from_unscaled(self.ocp, out, "x").to_dict(to_merge=to_merge, scaled=False)

return out if len(out) > 1 else out[0]
if return_time:
time_vector = self._return_time_vector(to_merge=to_merge, duplicated_times=duplicated_times)
return out if len(out) > 1 else out[0], time_vector if len(time_vector) > 1 else time_vector[0]
else:
return out if len(out) > 1 else out[0]

def _states_for_phase_integration(
self,
Expand Down Expand Up @@ -791,6 +854,7 @@ def _states_for_phase_integration(
# based on the phase transition objective or constraint function. That is why we need to concatenate
# twice the last state
x = PenaltyHelpers.states(penalty, 0, lambda p, n, sn: integrated_states[-1])

u = PenaltyHelpers.controls(
penalty,
0,
Expand All @@ -812,7 +876,7 @@ def _states_for_phase_integration(
f"please integrate with Shooting.SINGLE_DISCONTINUOUS_PHASE."
)

return [decision_states[phase_idx][0] + dx]
return [(integrated_states[-1] if shooting_type == Shooting.SINGLE else decision_states[phase_idx][0]) + dx]

def _integrate_stepwise(self) -> None:
"""
Expand Down Expand Up @@ -854,6 +918,34 @@ def _integrate_stepwise(self) -> None:

self._stepwise_states = SolutionData.from_unscaled(self.ocp, unscaled, "x")

def _return_time_vector(self, to_merge: SolutionMerge | list[SolutionMerge], duplicated_times: bool):
"""
Returns the time vector at each node that matches stepwise_states or stepwise_controls
Parameters
----------
to_merge: SolutionMerge | list[SolutionMerge, ...]
The merge type to perform. If None, then no merge is performed.
duplicated_times: bool
If the times should be duplicated for each node.
If False, then the returned time vector will not have any duplicated times.
Returns
-------
The time vector at each node that matches stepwise_states or stepwise_controls
"""
if to_merge is None:
to_merge = []
if isinstance(to_merge, SolutionMerge):
to_merge = [to_merge]
if SolutionMerge.NODES and SolutionMerge.PHASES in to_merge:
time_vector = np.concatenate(self.stepwise_time(to_merge=to_merge, duplicated_times=duplicated_times))
elif SolutionMerge.NODES in to_merge:
time_vector = self.stepwise_time(to_merge=to_merge, duplicated_times=duplicated_times)
for i in range(len(self.ocp.nlp)):
time_vector[i] = np.concatenate(time_vector[i])
else:
time_vector = self.stepwise_time(to_merge=to_merge, duplicated_times=duplicated_times)
return time_vector

def interpolate(self, n_frames: int | list | tuple, scaled: bool = False):
"""
Interpolate the states
Expand Down
26 changes: 26 additions & 0 deletions tests/shard4/test_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,32 @@ def test_integrate_multiphase(shooting, integrator, ode_solver, phase_dynamics,
assert sol_integrated[i][key].shape == (shapes[k], n_shooting[i] * n_steps + 1)
assert states[i][key].shape == (shapes[k], n_shooting[i] * n_steps + 1)

sol_integrated, sol_time = sol.integrate(
shooting_type=shooting,
integrator=integrator,
to_merge=[SolutionMerge.PHASES, SolutionMerge.NODES],
duplicated_times=False,
return_time=True,
)

assert len(sol_time) == len(np.unique(sol_time))
for i in range(len(sol_integrated)):
for k, key in enumerate(states[i]):
assert len(sol_integrated[key][k]) == len(sol_time)

sol_integrated, sol_time = sol.integrate(
shooting_type=shooting,
integrator=integrator,
to_merge=[SolutionMerge.NODES],
duplicated_times=False,
return_time=True,
)
for i in range(len(sol_time)):
assert len(sol_time[i]) == len(np.unique(sol_time[i]))
for i in range(len(sol_integrated)):
for k, key in enumerate(states[i]):
assert len(sol_integrated[i][key][k]) == len(sol_time[i])


def test_check_models_comes_from_same_super_class():
from bioptim.examples.getting_started import example_multiphase as ocp_module
Expand Down
20 changes: 2 additions & 18 deletions tests/shard6/test_global_stochastic_except_collocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,7 @@ def test_arm_reaching_muscle_driven(use_sx):
sensory_noise_magnitude = vertcat(wPq_magnitude, wPqdot_magnitude)

if use_sx:
if platform.system() == "Windows":
# It is not possible to test the error message on Windows as it uses absolute path
match = None
else:
match = re.escape(
"Error in Function::call for 'tp' [MXFunction] at .../casadi/core/function.cpp:339:\n"
".../casadi/core/linsol_internal.cpp:65: eval_sx not defined for LinsolQr"
)
with pytest.raises(RuntimeError, match=match):
with pytest.raises(RuntimeError, match=".*eval_sx not defined for LinsolQr"):
ocp = ocp_module.prepare_socp(
final_time=final_time,
n_shooting=n_shooting,
Expand Down Expand Up @@ -277,15 +269,7 @@ def test_arm_reaching_torque_driven_explicit(use_sx):
bioptim_folder = os.path.dirname(ocp_module.__file__)

if use_sx:
if platform.system() == "Windows":
# It is not possible to test the error message on Windows as it uses absolute path
match = None
else:
match = re.escape(
"Error in Function::call for 'tp' [MXFunction] at .../casadi/core/function.cpp:339:\n"
".../casadi/core/linsol_internal.cpp:65: eval_sx not defined for LinsolQr"
)
with pytest.raises(RuntimeError, match=match):
with pytest.raises(RuntimeError, match=".*eval_sx not defined for LinsolQr"):
ocp = ocp_module.prepare_socp(
biorbd_model_path=bioptim_folder + "/models/LeuvenArmModel.bioMod",
final_time=final_time,
Expand Down

0 comments on commit 45ff432

Please sign in to comment.