Skip to content

Commit

Permalink
Merge pull request #850 from Kev1CO/rk_correction
Browse files Browse the repository at this point in the history
Correcting RK4 and RK8 time integration calculation
  • Loading branch information
pariterre authored Feb 21, 2024
2 parents 770d395 + faf2b52 commit acc9081
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 22 deletions.
56 changes: 35 additions & 21 deletions bioptim/dynamics/integrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ class RK1(RK):
"""

def next_x(self, t0: float | MX | SX, x_prev: MX | SX, u: MX | SX, p: MX | SX, a: MX | SX) -> MX | SX:
return x_prev + self.h * self.fun(t0, x_prev, self.get_u(u, t0), p, a)[:, self.ode_idx]
return x_prev + self.h * self.fun(vertcat(t0, self.h), x_prev, self.get_u(u, t0), p, a)[:, self.ode_idx]


class RK2(RK):
Expand All @@ -317,7 +317,11 @@ def next_x(self, t0: float | MX | SX, x_prev: MX | SX, u: MX | SX, p: MX | SX, a
h = self.h

k1 = self.fun(vertcat(t0, h), x_prev, self.get_u(u, t0), p, a)[:, self.ode_idx]
return x_prev + h * self.fun(t0, x_prev + h / 2 * k1, self.get_u(u, t0 + h / 2), p, a)[:, self.ode_idx]
return (
x_prev
+ h
* self.fun(vertcat(t0 + h / 2, h), x_prev + h / 2 * k1, self.get_u(u, t0 + h / 2), p, a)[:, self.ode_idx]
)


class RK4(RK):
Expand All @@ -327,12 +331,11 @@ class RK4(RK):

def next_x(self, t0: float | MX | SX, x_prev: MX | SX, u: MX | SX, p: MX | SX, a: MX | SX):
h = self.h
t = vertcat(t0, h)

k1 = self.fun(t, x_prev, self.get_u(u, t0), p, a)[:, self.ode_idx]
k2 = self.fun(t, x_prev + h / 2 * k1, self.get_u(u, t0 + h / 2), p, a)[:, self.ode_idx]
k3 = self.fun(t, x_prev + h / 2 * k2, self.get_u(u, t0 + h / 2), p, a)[:, self.ode_idx]
k4 = self.fun(t, x_prev + h * k3, self.get_u(u, t0 + h), p, a)[:, self.ode_idx]
k1 = self.fun(vertcat(t0, h), x_prev, self.get_u(u, t0), p, a)[:, self.ode_idx]
k2 = self.fun(vertcat(t0 + h / 2, h), x_prev + h / 2 * k1, self.get_u(u, t0 + h / 2), p, a)[:, self.ode_idx]
k3 = self.fun(vertcat(t0 + h / 2, h), x_prev + h / 2 * k2, self.get_u(u, t0 + h / 2), p, a)[:, self.ode_idx]
k4 = self.fun(vertcat(t0 + h, h), x_prev + h * k3, self.get_u(u, t0 + h), p, a)[:, self.ode_idx]
return x_prev + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


Expand All @@ -343,39 +346,50 @@ class RK8(RK4):

def next_x(self, t0: float | MX | SX, x_prev: MX | SX, u: MX | SX, p: MX | SX, a: MX | SX):
h = self.h
t = vertcat(t0, h)

k1 = self.fun(t, x_prev, self.get_u(u, t0), p, a)[:, self.ode_idx]
k2 = self.fun(t, x_prev + (h * 4 / 27) * k1, self.get_u(u, t0 + h * (4 / 27)), p, a)[:, self.ode_idx]
k3 = self.fun(t, x_prev + (h / 18) * (k1 + 3 * k2), self.get_u(u, t0 + h * (2 / 9)), p, a)[:, self.ode_idx]
k4 = self.fun(t, x_prev + (h / 12) * (k1 + 3 * k3), self.get_u(u, t0 + h * (1 / 3)), p, a)[:, self.ode_idx]
k5 = self.fun(t, x_prev + (h / 8) * (k1 + 3 * k4), self.get_u(u, t0 + h * (1 / 2)), p, a)[:, self.ode_idx]
k1 = self.fun(vertcat(t0, h), x_prev, self.get_u(u, t0), p, a)[:, self.ode_idx]
k2 = self.fun(vertcat(t0 + h * 4 / 27, h), x_prev + (h * 4 / 27) * k1, self.get_u(u, t0 + h * (4 / 27)), p, a)[
:, self.ode_idx
]
k3 = self.fun(
vertcat(t0 + h / 18, h), x_prev + (h / 18) * (k1 + 3 * k2), self.get_u(u, t0 + h * (2 / 9)), p, a
)[:, self.ode_idx]
k4 = self.fun(
vertcat(t0 + h / 12, h), x_prev + (h / 12) * (k1 + 3 * k3), self.get_u(u, t0 + h * (1 / 3)), p, a
)[:, self.ode_idx]
k5 = self.fun(vertcat(t0 + h / 8, h), x_prev + (h / 8) * (k1 + 3 * k4), self.get_u(u, t0 + h * (1 / 2)), p, a)[
:, self.ode_idx
]
k6 = self.fun(
t, x_prev + (h / 54) * (13 * k1 - 27 * k3 + 42 * k4 + 8 * k5), self.get_u(u, t0 + h * (2 / 3)), p, a
vertcat(t0 + h / 54, h),
x_prev + (h / 54) * (13 * k1 - 27 * k3 + 42 * k4 + 8 * k5),
self.get_u(u, t0 + h * (2 / 3)),
p,
a,
)[:, self.ode_idx]
k7 = self.fun(
t,
vertcat(t0 + h / 4320, h),
x_prev + (h / 4320) * (389 * k1 - 54 * k3 + 966 * k4 - 824 * k5 + 243 * k6),
self.get_u(u, t0 + h * (1 / 6)),
p,
a,
)[:, self.ode_idx]
k8 = self.fun(
t,
vertcat(t0 + h / 20, h),
x_prev + (h / 20) * (-234 * k1 + 81 * k3 - 1164 * k4 + 656 * k5 - 122 * k6 + 800 * k7),
self.get_u(u, t0 + h),
p,
a,
)[:, self.ode_idx]
k9 = self.fun(
t,
vertcat(t0 + h / 288, h),
x_prev + (h / 288) * (-127 * k1 + 18 * k3 - 678 * k4 + 456 * k5 - 9 * k6 + 576 * k7 + 4 * k8),
self.get_u(u, t0 + h * (5 / 6)),
p,
a,
)[:, self.ode_idx]
k10 = self.fun(
t,
vertcat(t0 + h / 820, h),
x_prev
+ (h / 820) * (1481 * k1 - 81 * k3 + 7104 * k4 - 3376 * k5 + 72 * k6 - 5040 * k7 - 60 * k8 + 720 * k9),
self.get_u(u, t0 + h),
Expand Down Expand Up @@ -408,8 +422,8 @@ def next_x(
a_prev: MX | SX,
a_next: MX | SX,
):
dx = self.fun(t0, x_prev, u_prev, p, a_prev)[:, self.ode_idx]
dx_next = self.fun(t0, x_next, u_next, p, a_next)[:, self.ode_idx]
dx = self.fun(vertcat(t0, self.h), x_prev, u_prev, p, a_prev)[:, self.ode_idx]
dx_next = self.fun(vertcat(t0 + self.h, self.h), x_next, u_next, p, a_next)[:, self.ode_idx]
return x_prev + (dx + dx_next) * self.h / 2

@property
Expand Down Expand Up @@ -597,7 +611,7 @@ def dxdt(
states_end = self._d[0] * states[1]
defects = []
for j in range(1, self.degree + 1):
t = vertcat(self.t_span_sym[0] + self._integration_time[j - 1] * self.h, self.h)
t = vertcat(self.t_span_sym[0] + self._integration_time[j] * self.h, self.h)

# Expression for the state derivative at the collocation point
xp_j = 0
Expand Down
2 changes: 1 addition & 1 deletion bioptim/gui/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,7 @@ def find_phases_intersections(self):
Finds the intersection between the phases
"""

return list(accumulate([t[-1][-1] for t in self.t_integrated]))[:-1]
return list([t[-1][-1] for t in self.t_integrated])[:-1]

@staticmethod
def show():
Expand Down

0 comments on commit acc9081

Please sign in to comment.