From 41d8197bbc2627aa6742d0358e16ade0a93a4747 Mon Sep 17 00:00:00 2001 From: John Siirola Date: Wed, 21 Feb 2024 16:23:41 -0700 Subject: [PATCH] Support config domains with either method or attribute domain_name --- pyomo/common/config.py | 6 +++++- pyomo/common/tests/test_config.py | 35 +++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/pyomo/common/config.py b/pyomo/common/config.py index 238bdd78e9d..f9c3a725bb8 100644 --- a/pyomo/common/config.py +++ b/pyomo/common/config.py @@ -1134,7 +1134,11 @@ def _domain_name(domain): if domain is None: return "" elif hasattr(domain, 'domain_name'): - return domain.domain_name() + dn = domain.domain_name + if hasattr(dn, '__call__'): + return dn() + else: + return dn elif domain.__class__ is type: return domain.__name__ elif inspect.isfunction(domain): diff --git a/pyomo/common/tests/test_config.py b/pyomo/common/tests/test_config.py index 0bbed43423d..12657481764 100644 --- a/pyomo/common/tests/test_config.py +++ b/pyomo/common/tests/test_config.py @@ -3265,6 +3265,41 @@ def __init__( OUT.getvalue().replace('null', 'None'), ) + def test_domain_name(self): + cfg = ConfigDict() + + cfg.declare('none', ConfigValue()) + self.assertEqual(cfg.get('none').domain_name(), '') + + def fcn(val): + return val + + cfg.declare('fcn', ConfigValue(domain=fcn)) + self.assertEqual(cfg.get('fcn').domain_name(), 'fcn') + + fcn.domain_name = 'custom fcn' + self.assertEqual(cfg.get('fcn').domain_name(), 'custom fcn') + + class functor: + def __call__(self, val): + return val + + cfg.declare('functor', ConfigValue(domain=functor())) + self.assertEqual(cfg.get('functor').domain_name(), 'functor') + + class cfunctor: + def __call__(self, val): + return val + + def domain_name(self): + return 'custom functor' + + cfg.declare('cfunctor', ConfigValue(domain=cfunctor())) + self.assertEqual(cfg.get('cfunctor').domain_name(), 'custom functor') + + cfg.declare('type', ConfigValue(domain=int)) + self.assertEqual(cfg.get('type').domain_name(), 'int') + if __name__ == "__main__": unittest.main()