Skip to content

Commit

Permalink
Merge pull request #900 from EveCharbie/time
Browse files Browse the repository at this point in the history
dt dependent integration gub fix
  • Loading branch information
pariterre authored Nov 8, 2024
2 parents fd3b8d1 + c30a55a commit 4c44763
Show file tree
Hide file tree
Showing 5 changed files with 435 additions and 29 deletions.
55 changes: 36 additions & 19 deletions bioptim/dynamics/integrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,10 @@ def _time_xall_from_dt_func(self) -> Function:
def h(self):
return self.t_span_sym[1] / self._n_step

@property
def dt(self):
return self.t_span_sym[1]

def next_x(self, t0: float | MX | SX, x_prev: MX | SX, u: MX | SX, p: MX | SX, a: MX | SX, d: MX | SX) -> MX | SX:
"""
Compute the next integrated state (abstract)
Expand Down Expand Up @@ -323,7 +327,7 @@ class RK1(RK):
"""

def next_x(self, t0: float | MX | SX, x_prev: MX | SX, u: MX | SX, p: MX | SX, a: MX | SX, d: MX | SX) -> MX | SX:
return x_prev + self.h * self.fun(vertcat(t0, self.h), x_prev, self.get_u(u, t0), p, a, d)[:, self.ode_idx]
return x_prev + self.h * self.fun(vertcat(t0, self.dt), x_prev, self.get_u(u, t0), p, a, d)[:, self.ode_idx]


class RK2(RK):
Expand All @@ -333,12 +337,15 @@ class RK2(RK):

def next_x(self, t0: float | MX | SX, x_prev: MX | SX, u: MX | SX, p: MX | SX, a: MX | SX, d: MX | SX) -> MX | SX:
h = self.h
dt = self.dt

k1 = self.fun(vertcat(t0, h), x_prev, self.get_u(u, t0), p, a, d)[:, self.ode_idx]
k1 = self.fun(vertcat(t0, dt), x_prev, self.get_u(u, t0), p, a, d)[:, self.ode_idx]
return (
x_prev
+ h
* self.fun(vertcat(t0 + h / 2, h), x_prev + h / 2 * k1, self.get_u(u, t0 + h / 2), p, a, d)[:, self.ode_idx]
* self.fun(vertcat(t0 + h / 2, dt), x_prev + h / 2 * k1, self.get_u(u, t0 + h / 2), p, a, d)[
:, self.ode_idx
]
)


Expand All @@ -349,11 +356,12 @@ class RK4(RK):

def next_x(self, t0: float | MX | SX, x_prev: MX | SX, u: MX | SX, p: MX | SX, a: MX | SX, d: MX | SX) -> MX | SX:
h = self.h
dt = self.dt

k1 = self.fun(vertcat(t0, h), x_prev, self.get_u(u, t0), p, a, d)[:, self.ode_idx]
k2 = self.fun(vertcat(t0 + h / 2, h), x_prev + h / 2 * k1, self.get_u(u, t0 + h / 2), p, a, d)[:, self.ode_idx]
k3 = self.fun(vertcat(t0 + h / 2, h), x_prev + h / 2 * k2, self.get_u(u, t0 + h / 2), p, a, d)[:, self.ode_idx]
k4 = self.fun(vertcat(t0 + h, h), x_prev + h * k3, self.get_u(u, t0 + h), p, a, d)[:, self.ode_idx]
k1 = self.fun(vertcat(t0, dt), x_prev, self.get_u(u, t0), p, a, d)[:, self.ode_idx]
k2 = self.fun(vertcat(t0 + h / 2, dt), x_prev + h / 2 * k1, self.get_u(u, t0 + h / 2), p, a, d)[:, self.ode_idx]
k3 = self.fun(vertcat(t0 + h / 2, dt), x_prev + h / 2 * k2, self.get_u(u, t0 + h / 2), p, a, d)[:, self.ode_idx]
k4 = self.fun(vertcat(t0 + h, dt), x_prev + h * k3, self.get_u(u, t0 + h), p, a, d)[:, self.ode_idx]
return x_prev + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


Expand All @@ -364,74 +372,75 @@ class RK8(RK4):

def next_x(self, t0: float | MX | SX, x_prev: MX | SX, u: MX | SX, p: MX | SX, a: MX | SX, d: MX | SX) -> MX | SX:
h = self.h
dt = self.dt

k1 = self.fun(vertcat(t0, h), x_prev, self.get_u(u, t0), p, a, d)[:, self.ode_idx]
k1 = self.fun(vertcat(t0, dt), x_prev, self.get_u(u, t0), p, a, d)[:, self.ode_idx]
k2 = self.fun(
vertcat(t0 + h * 4 / 27, h),
vertcat(t0 + h * 4 / 27, dt),
x_prev + (h * 4 / 27) * k1,
self.get_u(u, t0 + h * (4 / 27)),
p,
a,
d,
)[:, self.ode_idx]
k3 = self.fun(
vertcat(t0 + h / 18, h),
vertcat(t0 + h / 18, dt),
x_prev + (h / 18) * (k1 + 3 * k2),
self.get_u(u, t0 + h * (2 / 9)),
p,
a,
d,
)[:, self.ode_idx]
k4 = self.fun(
vertcat(t0 + h / 12, h),
vertcat(t0 + h / 12, dt),
x_prev + (h / 12) * (k1 + 3 * k3),
self.get_u(u, t0 + h * (1 / 3)),
p,
a,
d,
)[:, self.ode_idx]
k5 = self.fun(
vertcat(t0 + h / 8, h),
vertcat(t0 + h / 8, dt),
x_prev + (h / 8) * (k1 + 3 * k4),
self.get_u(u, t0 + h * (1 / 2)),
p,
a,
d,
)[:, self.ode_idx]
k6 = self.fun(
vertcat(t0 + h / 54, h),
vertcat(t0 + h / 54, dt),
x_prev + (h / 54) * (13 * k1 - 27 * k3 + 42 * k4 + 8 * k5),
self.get_u(u, t0 + h * (2 / 3)),
p,
a,
d,
)[:, self.ode_idx]
k7 = self.fun(
vertcat(t0 + h / 4320, h),
vertcat(t0 + h / 4320, dt),
x_prev + (h / 4320) * (389 * k1 - 54 * k3 + 966 * k4 - 824 * k5 + 243 * k6),
self.get_u(u, t0 + h * (1 / 6)),
p,
a,
d,
)[:, self.ode_idx]
k8 = self.fun(
vertcat(t0 + h / 20, h),
vertcat(t0 + h / 20, dt),
x_prev + (h / 20) * (-234 * k1 + 81 * k3 - 1164 * k4 + 656 * k5 - 122 * k6 + 800 * k7),
self.get_u(u, t0 + h),
p,
a,
d,
)[:, self.ode_idx]
k9 = self.fun(
vertcat(t0 + h / 288, h),
vertcat(t0 + h / 288, dt),
x_prev + (h / 288) * (-127 * k1 + 18 * k3 - 678 * k4 + 456 * k5 - 9 * k6 + 576 * k7 + 4 * k8),
self.get_u(u, t0 + h * (5 / 6)),
p,
a,
d,
)[:, self.ode_idx]
k10 = self.fun(
vertcat(t0 + h / 820, h),
vertcat(t0 + h / 820, dt),
x_prev
+ (h / 820) * (1481 * k1 - 81 * k3 + 7104 * k4 - 3376 * k5 + 72 * k6 - 5040 * k7 - 60 * k8 + 720 * k9),
self.get_u(u, t0 + h),
Expand Down Expand Up @@ -467,8 +476,8 @@ def next_x(
d_prev: MX | SX,
d_next: MX | SX,
):
dx = self.fun(vertcat(t0, self.h), x_prev, u_prev, p, a_prev, d_prev)[:, self.ode_idx]
dx_next = self.fun(vertcat(t0 + self.h, self.h), x_next, u_next, p, a_next, d_next)[:, self.ode_idx]
dx = self.fun(vertcat(t0, self.dt), x_prev, u_prev, p, a_prev, d_prev)[:, self.ode_idx]
dx_next = self.fun(vertcat(t0 + self.h, self.dt), x_next, u_next, p, a_next, d_next)[:, self.ode_idx]
return x_prev + (dx + dx_next) * self.h / 2

@property
Expand All @@ -491,6 +500,10 @@ def shape_xall(self):
def h(self):
return self.t_span_sym[1]

@property
def dt(self):
return self.t_span_sym[1]

def dxdt(
self,
states: MX | SX,
Expand Down Expand Up @@ -618,6 +631,10 @@ def _output_names(self):
def h(self):
return self.t_span_sym[1]

@property
def dt(self):
return self.t_span_sym[1]

@property
def _integration_time(self):
return [0] + collocation_points(self.degree, self.method)
Expand Down
2 changes: 1 addition & 1 deletion bioptim/optimization/non_linear_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@ def node_time(self, node_idx: int):
"""
if node_idx < 0 or node_idx > self.ns:
return ValueError(f"node_index out of range [0:{self.ns}]")
return self.tf / self.ns * node_idx
return self.dt * node_idx

def get_var_from_states_or_controls(
self, key: str, states: MX.sym, controls: MX.sym, algebraic_states: MX.sym = None
Expand Down
2 changes: 1 addition & 1 deletion bioptim/optimization/optimal_control_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -1771,7 +1771,7 @@ def node_time(self, phase_idx: int, node_idx: int):
raise ValueError(f"node_index out of range [0:{self.nlp[phase_idx].ns}]")
previous_phase_time = sum([nlp.tf for nlp in self.nlp[:phase_idx]])

return previous_phase_time + self.nlp[phase_idx].tf * node_idx / self.nlp[phase_idx].ns
return previous_phase_time + self.nlp[phase_idx].dt * node_idx

def _set_default_ode_solver(self):
"""
Expand Down
Loading

0 comments on commit 4c44763

Please sign in to comment.