diff --git a/test/dynamics/common.py b/test/dynamics/common.py index a27ee060e..92cce526c 100644 --- a/test/dynamics/common.py +++ b/test/dynamics/common.py @@ -163,10 +163,11 @@ def test_array_backends(test_class, backends=None): # reference to module that called this function module = inspect.getmodule(inspect.stack()[1][0]) - libs = ["numpy", "jax", "array_numpy", "array_jax"] - base_classes = [NumpyTestBase, JaxTestBase, ArrayNumpyTestBase, ArrayJaxTestBase] - for lib, base_class in zip(libs, base_classes): - if lib in backends: + classes = inspect.getmembers(inspect.getmodule(inspect.currentframe()), inspect.isclass) + base_classes = [cls[1] for cls in classes if hasattr(cls[1], "lib")] + + for base_class in base_classes: + lib = base_class.lib() class_name = f"{test_class.__name__}_{lib}" setattr(module, class_name, type(class_name, (test_class, base_class), {}))