Skip to content

Commit

Permalink
Removed "..." in type hints of list
Browse files Browse the repository at this point in the history
  • Loading branch information
pariterre committed Jun 19, 2024
1 parent 5bd32ea commit 50109c0
Show file tree
Hide file tree
Showing 12 changed files with 44 additions and 48 deletions.
6 changes: 3 additions & 3 deletions bioptim/dynamics/configure_new_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def check_variable_copy_condition(
and name in getattr(nlp[use_from_phase_idx], decision_variable_attribute)
)

def define_cx_scaled(self, n_col: int, n_shooting: int, initial_node) -> list[MX | SX, ...]:
def define_cx_scaled(self, n_col: int, n_shooting: int, initial_node) -> list[MX | SX]:
"""
This function defines the decision variables, either MX or SX,
scaled to the physical world, they mean something according to the physical model considered.
Expand Down Expand Up @@ -259,7 +259,7 @@ def define_cx_scaled(self, n_col: int, n_shooting: int, initial_node) -> list[MX
)
return _cx

def define_cx_unscaled(self, _cx_scaled: list[MX | SX, ...], scaling: np.ndarray) -> list[MX | SX, ...]:
def define_cx_unscaled(self, _cx_scaled: list[MX | SX], scaling: np.ndarray) -> list[MX | SX]:
"""
This function defines the decision variables, either MX or SX,
unscaled means here the decision variable doesn't correspond to physical quantity.
Expand All @@ -269,7 +269,7 @@ def define_cx_unscaled(self, _cx_scaled: list[MX | SX, ...], scaling: np.ndarray
Parameters
---------
_cx_scaled: list[MX | SX, ...]
_cx_scaled: list[MX | SX]
Decision variables scaled to the physical world
scaling: np.ndarray
The scaling factors associated to the decision variable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@


def custom_multinode_constraint(
controllers: list[PenaltyController, ...], coef: float, states_mapping: BiMapping = None
controllers: list[PenaltyController], coef: float, states_mapping: BiMapping = None
) -> MX:
"""
The constraint of the transition. The values from the end of the phase to the next are multiplied by coef to
Expand All @@ -42,7 +42,7 @@ def custom_multinode_constraint(
Parameters
----------
controllers: list[PenaltyController, ...]
controllers: list[PenaltyController]
All the controller for the penalties
coef: float
The coefficient of the phase transition (makes no physical sens)
Expand Down
2 changes: 1 addition & 1 deletion bioptim/gui/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,7 +778,7 @@ def _compute_y_from_plot_func(
The custom plot to compute
phase_idx: int
The index of the current phase
time_stepwise: list[list[DM], ...]
time_stepwise: list[list[DM]]
The time vector of each phase
dt
The delta times of the current phase
Expand Down
4 changes: 2 additions & 2 deletions bioptim/limits/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def add_or_replace_to_penalty_pool(self, ocp, nlp):
elif self.bounds.shape[0] != len(self.rows):
raise RuntimeError(f"bounds rows is {self.bounds.shape[0]} but should be {self.rows} or empty")

def _add_penalty_to_pool(self, controller: list[PenaltyController, ...]):
def _add_penalty_to_pool(self, controller: list[PenaltyController]):
controller = controller[0] # This is a special case of Node.TRANSITION

if self.penalty_type == PenaltyType.INTERNAL:
Expand Down Expand Up @@ -1222,7 +1222,7 @@ def add_or_replace_to_penalty_pool(self, ocp, nlp):
elif self.bounds.shape[0] != len(self.rows):
raise RuntimeError(f"bounds rows is {self.bounds.shape[0]} but should be {self.rows} or empty")

def _add_penalty_to_pool(self, controller: list[PenaltyController, ...]):
def _add_penalty_to_pool(self, controller: list[PenaltyController]):
controller = controller[0] # This is a special case of Node.TRANSITION

if self.penalty_type == PenaltyType.INTERNAL:
Expand Down
24 changes: 12 additions & 12 deletions bioptim/limits/multinode_penalty.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def __init__(
def _get_pool_to_add_penalty(self, ocp, nlp):
raise NotImplementedError("This is an abstract method and should be implemented by child")

def _add_penalty_to_pool(self, controller: list[PenaltyController, ...]):
def _add_penalty_to_pool(self, controller: list[PenaltyController]):

controller = controller[0] # This is a special case of Node.TRANSITION

Expand Down Expand Up @@ -124,9 +124,9 @@ class Functions:
@staticmethod
def states_equality(
penalty,
controllers: list[PenaltyController, ...],
controllers: list[PenaltyController],
key: str = "all",
states_mapping: list[BiMapping, ...] = None,
states_mapping: list[BiMapping] = None,
):
"""
The most common continuity function, that is state before equals state after
Expand Down Expand Up @@ -171,15 +171,15 @@ def states_equality(
return out

@staticmethod
def controls_equality(penalty, controllers: list[PenaltyController, ...], key: str = "all"):
def controls_equality(penalty, controllers: list[PenaltyController], key: str = "all"):
"""
The controls before equals controls after
Parameters
----------
penalty : MultinodePenalty
A reference to the penalty
controllers: list[PenaltyController, ...]
controllers: list[PenaltyController]
The penalty node elements
Returns
Expand Down Expand Up @@ -210,7 +210,7 @@ def controls_equality(penalty, controllers: list[PenaltyController, ...], key: s
@staticmethod
def algebraic_states_equality(
penalty,
controllers: list[PenaltyController, ...],
controllers: list[PenaltyController],
key: str = "all",
):
"""
Expand Down Expand Up @@ -249,15 +249,15 @@ def algebraic_states_equality(
return out

@staticmethod
def com_equality(penalty, controllers: list[PenaltyController, ...]):
def com_equality(penalty, controllers: list[PenaltyController]):
"""
The centers of mass are equals for the specified phases and the specified nodes
Parameters
----------
penalty : MultinodePenalty
A reference to the penalty
controllers: list[PenaltyController, ...]
controllers: list[PenaltyController]
The penalty node elements
Returns
Expand All @@ -277,15 +277,15 @@ def com_equality(penalty, controllers: list[PenaltyController, ...]):
return out

@staticmethod
def com_velocity_equality(penalty, controllers: list[PenaltyController, ...]):
def com_velocity_equality(penalty, controllers: list[PenaltyController]):
"""
The centers of mass velocity are equals for the specified phases and the specified nodes
Parameters
----------
penalty : MultinodePenalty
A reference to the penalty
controllers: list[PenaltyController, ...]
controllers: list[PenaltyController]
The penalty node elements
Returns
Expand Down Expand Up @@ -651,7 +651,7 @@ def custom(penalty, controllers: list[PenaltyController, PenaltyController], **e
return penalty.custom_function(controllers, **extra_parameters)

@staticmethod
def _prepare_controller_cx(penalty, controllers: list[PenaltyController, ...]):
def _prepare_controller_cx(penalty, controllers: list[PenaltyController]):
"""
This calls the _compute_controller_cx function for each of the controller then dispatch the cx appropriately
to the controllers
Expand All @@ -667,7 +667,7 @@ def _prepare_controller_cx(penalty, controllers: list[PenaltyController, ...]):
c.cx_index_to_get = index

@staticmethod
def _prepare_states_mapping(controllers: list[PenaltyController, ...], states_mapping: list[BiMapping, ...]):
def _prepare_states_mapping(controllers: list[PenaltyController], states_mapping: list[BiMapping]):
"""
Prepare a new state_mappings if None is sent. Otherwise, it simply returns the current states_mapping
Expand Down
6 changes: 3 additions & 3 deletions bioptim/limits/penalty_option.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,15 @@ class PenaltyOption(OptionGeneric):
_check_target_dimensions(self, controller: PenaltyController, n_frames: int)
Checks if the variable index is consistent with the requested variable.
If the function returns, all is okay
_set_penalty_function(self, controller: list[PenaltyController, ...], fcn: MX | SX)
_set_penalty_function(self, controller: list[PenaltyController], fcn: MX | SX)
Finalize the preparation of the penalty (setting function and weighted_function)
add_target_to_plot(self, controller: PenaltyController, combine_to: str)
Interface to the plot so it can be properly added to the proper plot
_finish_add_target_to_plot(self, controller: PenaltyController)
Internal interface to add (after having check the target dimensions) the target to the plot if needed
add_or_replace_to_penalty_pool(self, ocp, nlp)
Doing some configuration on the penalty and add it to the list of penalty
_add_penalty_to_pool(self, controller: list[PenaltyController, ...])
_add_penalty_to_pool(self, controller: list[PenaltyController])
Return the penalty pool for the specified penalty (abstract)
ensure_penalty_sanity(self, ocp, nlp)
Resets a penalty. A negative penalty index creates a new empty penalty (abstract)
Expand Down Expand Up @@ -332,7 +332,7 @@ def transform_penalty_to_stochastic(self, controller: PenaltyController, fcn, st

return diag(fcn_variation)

def _set_phase_dynamics(self, controllers: list[PenaltyController, ...]):
def _set_phase_dynamics(self, controllers: list[PenaltyController]):
phase_dynamics = [c.get_nlp.phase_dynamics for c in controllers]
if self.phase_dynamics:
# If it was already set (e.g. for multinode), we want to make sure it is consistent
Expand Down
4 changes: 2 additions & 2 deletions bioptim/limits/phase_transition.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ class Functions:
def continuous(
transition,
controllers: list[PenaltyController, PenaltyController],
states_mapping: list[BiMapping, ...] = None,
states_mapping: list[BiMapping] = None,
):
"""
The most common continuity function, that is state before equals state after
Expand Down Expand Up @@ -173,7 +173,7 @@ def continuous(
def continuous_controls(
transition,
controllers: list[PenaltyController, PenaltyController],
controls_mapping: list[BiMapping, ...] = None,
controls_mapping: list[BiMapping] = None,
):
"""
This continuity function is only relevant for ControlType.LINEAR_CONTINUOUS otherwise don't use it.
Expand Down
2 changes: 1 addition & 1 deletion bioptim/models/biorbd/multi_biorbd_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,7 +725,7 @@ def ligament_joint_torque(self, q, qdot) -> MX:
def ranges_from_model(self, variable: str):
return [the_range for model in self.models for the_range in model.ranges_from_model(variable)]

def bounds_from_ranges(self, variables: str | list[str, ...], mapping: BiMapping | BiMappingList = None) -> Bounds:
def bounds_from_ranges(self, variables: str | list[str], mapping: BiMapping | BiMappingList = None) -> Bounds:
return bounds_from_ranges(self, variables, mapping)

def _var_mapping(self, key: str, range_for_mapping: int | list | tuple | range, mapping: BiMapping = None) -> dict:
Expand Down
10 changes: 3 additions & 7 deletions bioptim/models/protocols/biomodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def passive_joint_torque(self, q, qdot) -> MX:
def ligament_joint_torque(self, q, qdot) -> MX:
"""Get the ligament joint torque"""

def bounds_from_ranges(self, variables: str | list[str, ...], mapping: BiMapping | BiMappingList = None) -> Bounds:
def bounds_from_ranges(self, variables: str | list[str], mapping: BiMapping | BiMappingList = None) -> Bounds:
"""
Create bounds from ranges of the model depending on the variable chosen, such as q, qdot, qddot
Expand Down Expand Up @@ -325,11 +325,7 @@ def partitioned_forward_dynamics(

@staticmethod
def animate(
ocp,
solution: "SolutionData",
show_now: bool = True,
tracked_markers: list[np.ndarray, ...] = None,
**kwargs: Any
ocp, solution: "SolutionData", show_now: bool = True, tracked_markers: list[np.ndarray] = None, **kwargs: Any
) -> None | list:
"""
Animate a solution
Expand All @@ -340,7 +336,7 @@ def animate(
The solution to animate
show_now: bool
If the animation should be shown immediately or not
tracked_markers: list[np.ndarray, ...]
tracked_markers: list[np.ndarray]
The tracked markers (3, n_markers, n_frames)
kwargs: dict
The options to pass to the animator
Expand Down
2 changes: 1 addition & 1 deletion bioptim/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def bounds_from_ranges(model, key: str, mapping: BiMapping | BiMappingList = Non
----------
model: bio_model
such as BiorbdModel or MultiBiorbdModel
key: str | list[str, ...]
key: str | list[str]
The variables to generate the bounds from, such as "q", "qdot", "qddot", or ["q", "qdot"],
mapping: BiMapping | BiMappingList
The mapping to use to generate the bounds. If None, the default mapping is built
Expand Down
10 changes: 5 additions & 5 deletions bioptim/optimization/optimization_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(
name: str,
mx: MX,
cx_start: list | None,
index: [range, list],
index: range | list,
mapping: BiMapping = None,
parent_list=None,
):
Expand All @@ -50,15 +50,15 @@ def __init__(
The name of the variable
mx: MX
The MX variable associated with this variable
index: [range, list]
index: range | list
The indices to find this variable
parent_list: OptimizationVariableList
The list the OptimizationVariable is in
"""
self.name: str = name
self.mx: MX = mx
self.original_cx: list = cx_start
self.index: [range, list] = index
self.index: range | list = index
self.mapping: BiMapping = mapping
self.parent_list: OptimizationVariableList = parent_list

Expand Down Expand Up @@ -494,8 +494,8 @@ def __init__(self, phase_dynamics: PhaseDynamics):
user sets it to something else)
"""
self.cx_constructor = None
self._unscaled: list[OptimizationVariableList, ...] = []
self._scaled: list[OptimizationVariableList, ...] = []
self._unscaled: list[OptimizationVariableList] = []
self._scaled: list[OptimizationVariableList] = []
self._node_index = 0 # TODO: [0] to [node_index]
self.phase_dynamics = phase_dynamics

Expand Down
18 changes: 9 additions & 9 deletions bioptim/optimization/solution/solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ def decision_time(
Parameters
----------
to_merge: SolutionMerge | list[SolutionMerge, ...]
to_merge: SolutionMerge | list[SolutionMerge]
The type of merge to perform. If None, then no merge is performed. It is often useful to merge NODES, but
is completely useless to merge KEYS
time_alignment: TimeAlignment
Expand Down Expand Up @@ -404,7 +404,7 @@ def stepwise_time(
Parameters
----------
to_merge: SolutionMerge | list[SolutionMerge, ...]
to_merge: SolutionMerge | list[SolutionMerge]
The type of merge to perform. If None, then no merge is performed. It is often useful to merge NODES, but
is completely useless to merge KEYS
time_alignment: TimeAlignment
Expand Down Expand Up @@ -533,7 +533,7 @@ def decision_states(self, scaled: bool = False, to_merge: SolutionMerge | list[S
scaled: bool
If the decision states should be scaled or not (note that scaled is as Ipopt received them, while unscaled
is as the model needs temps). If you don't know what it means, you probably want the unscaled version.
to_merge: SolutionMerge | list[SolutionMerge, ...]
to_merge: SolutionMerge | list[SolutionMerge]
The type of merge to perform. If None, then no merge is performed.
Returns
Expand All @@ -555,7 +555,7 @@ def stepwise_states(self, scaled: bool = False, to_merge: SolutionMerge | list[S
scaled: bool
If the states should be scaled or not (note that scaled is as Ipopt received them, while unscaled is as the
model needs temps). If you don't know what it means, you probably want the unscaled version.
to_merge: SolutionMerge | list[SolutionMerge, ...]
to_merge: SolutionMerge | list[SolutionMerge]
The type of merge to perform. If None, then no merge is performed.
Returns
Expand All @@ -580,7 +580,7 @@ def decision_controls(self, scaled: bool = False, to_merge: SolutionMerge | list
scaled : bool
If the decision controls should be scaled or not (note that scaled is as Ipopt received them, while unscaled
is as the model needs temps). If you don't know what it means, you probably want the unscaled version.
to_merge : SolutionMerge | list[SolutionMerge, ...]
to_merge : SolutionMerge | list[SolutionMerge]
The type of merge to perform. If None, then no merge is performed.
"""
return self.stepwise_controls(scaled=scaled, to_merge=to_merge)
Expand All @@ -594,7 +594,7 @@ def stepwise_controls(self, scaled: bool = False, to_merge: SolutionMerge | list
scaled: bool
If the controls should be scaled or not (note that scaled is as Ipopt received them, while unscaled is as
the model needs temps). If you don't know what it means, you probably want the unscaled version.
to_merge: SolutionMerge | list[SolutionMerge, ...]
to_merge: SolutionMerge | list[SolutionMerge]
The type of merge to perform. If None, then no merge is performed.
Returns
Expand Down Expand Up @@ -660,7 +660,7 @@ def decision_algebraic_states(self, scaled: bool = False, to_merge: SolutionMerg
scaled: bool
If the decision states should be scaled or not (note that scaled is as Ipopt received them, while unscaled
is as the model needs temps). If you don't know what it means, you probably want the unscaled version.
to_merge: SolutionMerge | list[SolutionMerge, ...]
to_merge: SolutionMerge | list[SolutionMerge]
The type of merge to perform. If None, then no merge is performed.
Returns
Expand Down Expand Up @@ -769,7 +769,7 @@ def integrate(
The integration shooting type to use
integrator: SolutionIntegrator
The type of integrator to use
to_merge: SolutionMerge | list[SolutionMerge, ...]
to_merge: SolutionMerge | list[SolutionMerge]
The type of merge to perform. If None, then no merge is performed.
duplicated_times: bool
If the times should be duplicated for each node.
Expand Down Expand Up @@ -1096,7 +1096,7 @@ def _return_time_vector(self, to_merge: SolutionMerge | list[SolutionMerge], dup
Returns the time vector at each node that matches stepwise_states or stepwise_controls
Parameters
----------
to_merge: SolutionMerge | list[SolutionMerge, ...]
to_merge: SolutionMerge | list[SolutionMerge]
The merge type to perform. If None, then no merge is performed.
duplicated_times: bool
If the times should be duplicated for each node.
Expand Down

0 comments on commit 50109c0

Please sign in to comment.