From ac0fc149b697f6a6b7f81c6df9319bbf8460c3c2 Mon Sep 17 00:00:00 2001 From: daniel-dudt Date: Wed, 7 Feb 2024 12:08:39 -0700 Subject: [PATCH] better checking for class types based on attrs --- desc/objectives/getters.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/desc/objectives/getters.py b/desc/objectives/getters.py index 2ed963b367..837e4a2c38 100644 --- a/desc/objectives/getters.py +++ b/desc/objectives/getters.py @@ -223,20 +223,26 @@ def _is_any_instance(things, cls): return any([isinstance(t, cls) for t in things]) # Equilibrium - if hasattr(thing, "_Rb_lmn") and hasattr(thing, "_Zb_lmn"): + if ( + hasattr(thing, "Ra_n") + and hasattr(thing, "Za_n") + and hasattr(thing, "Rb_lmn") + and hasattr(thing, "Zb_lmn") + and hasattr(thing, "L_lmn") + ): + if not _is_any_instance(constraints, AxisRSelfConsistency): + constraints += (AxisRSelfConsistency(eq=thing),) + if not _is_any_instance(constraints, AxisZSelfConsistency): + constraints += (AxisZSelfConsistency(eq=thing),) if not _is_any_instance(constraints, BoundaryRSelfConsistency): constraints += (BoundaryRSelfConsistency(eq=thing),) if not _is_any_instance(constraints, BoundaryZSelfConsistency): constraints += (BoundaryZSelfConsistency(eq=thing),) if not _is_any_instance(constraints, FixLambdaGauge): constraints += (FixLambdaGauge(eq=thing),) - if not _is_any_instance(constraints, AxisRSelfConsistency): - constraints += (AxisRSelfConsistency(eq=thing),) - if not _is_any_instance(constraints, AxisZSelfConsistency): - constraints += (AxisZSelfConsistency(eq=thing),) # Curve - elif hasattr(thing, "_shift") and hasattr(thing, "_rotmat"): + elif hasattr(thing, "shift") and hasattr(thing, "rotmat"): if not _is_any_instance(constraints, FixCurveShift): constraints += (FixCurveShift(curve=thing),) if not _is_any_instance(constraints, FixCurveRotation):