Skip to content

Commit

Permalink
refactor: check mapping dof (cleaning on the fly)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ipuch committed Nov 20, 2024
1 parent f1c8b2e commit 1993240
Showing 1 changed file with 25 additions and 36 deletions.
61 changes: 25 additions & 36 deletions bioptim/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,46 +3,35 @@


def _dof_mapping(key, model, mapping: BiMapping = None) -> dict:
if key == "q":
if model.nb_quaternions > 0 and mapping is not None:
if "q" in mapping and "qdot" not in mapping:
raise RuntimeError(
"It is not possible to provide a q_mapping but not a qdot_mapping if the model have quaternion"
)
return _var_mapping(key, range(model.nb_q), mapping)
elif key == "q_joints":
if model.nb_quaternions > 0 and mapping is not None:
if "q_joints" in mapping and "qdot_joints" not in mapping:
raise RuntimeError(
"It is not possible to provide a q_joints_mapping but not a qdot_joints_mapping if the model have quaternion"
)
return _var_mapping(key, range(model.nb_q - model.nb_root), mapping)
elif key == "q_roots":
return _var_mapping(key, range(model.nb_root), mapping)
elif key == "qdot":
if model.nb_quaternions > 0 and mapping is not None:
if "qdot" in mapping and "q" not in mapping:
raise RuntimeError(
"It is not possible to provide a qdot_mapping but not a q_mapping if the model have quaternion"
)
return _var_mapping(key, range(model.nb_qdot), mapping)
elif key == "qdot_joints":
if model.nb_quaternions > 0 and mapping is not None:
if "qdot_joints" in mapping and "q_joints" not in mapping:
raise RuntimeError(
"It is not possible to provide a qdot_joints_mapping but not a q_joints_mapping if the model have quaternion"
)
return _var_mapping(key, range(model.nb_qdot - model.nb_root), mapping)
elif key == "qdot_roots":
return _var_mapping(key, range(model.nb_root), mapping)
elif key == "qddot":
return _var_mapping(key, range(model.nb_qdot), mapping)
elif key == "qddot_joints":
return _var_mapping(key, range(model.nb_qdot - model.nb_root), mapping)
has_quaternion_and_mapping = model.nb_quaternions > 0 and mapping is not None
if has_quaternion_and_mapping:
_check_quaternion_mapping(key, mapping, model)

ranges_map = {
"q": range(model.nb_q),
"q_joints": range(model.nb_q - model.nb_root),
"q_roots": range(model.nb_root),
"qdot": range(model.nb_qdot),
"qdot_joints": range(model.nb_qdot - model.nb_root),
"qdot_roots": range(model.nb_root),
"qddot": range(model.nb_qdot),
"qddot_joints": range(model.nb_qdot - model.nb_root),
}

if key in ranges_map:
return _var_mapping(key, ranges_map[key], mapping)
else:
raise NotImplementedError("Wrong dof mapping")


def _check_quaternion_mapping(key, mapping, model):
required_mappings = {"q": "qdot", "q_joints": "qdot_joints", "qdot": "q", "qdot_joints": "q_joints"}
if key in mapping and required_mappings[key] not in mapping:
raise RuntimeError(
f"It is not possible to provide a {key}_mapping but not a {required_mappings[key]}_mapping if the model has quaternion"
)


def _var_mapping(key, range_for_mapping, mapping: BiMapping = None) -> dict:
"""
This function returns a standard mapping for the variable key if None.
Expand Down

0 comments on commit 1993240

Please sign in to comment.