diff --git a/news/214.feature b/news/214.feature new file mode 100644 index 000000000..38bfa08ed --- /dev/null +++ b/news/214.feature @@ -0,0 +1 @@ +New pydevd resolver plugin for easier debugging diff --git a/pydevd_plugins/__init__.py b/pydevd_plugins/__init__.py new file mode 100644 index 000000000..3a973c9d5 --- /dev/null +++ b/pydevd_plugins/__init__.py @@ -0,0 +1,6 @@ +try: + __import__("pkg_resources").declare_namespace(__name__) +except ImportError: # pragma: no cover + import pkgutil + + __path__ = pkgutil.extend_path(__path__, __name__) # type: ignore diff --git a/pydevd_plugins/extensions/__init__.py b/pydevd_plugins/extensions/__init__.py new file mode 100644 index 000000000..3a973c9d5 --- /dev/null +++ b/pydevd_plugins/extensions/__init__.py @@ -0,0 +1,6 @@ +try: + __import__("pkg_resources").declare_namespace(__name__) +except ImportError: # pragma: no cover + import pkgutil + + __path__ = pkgutil.extend_path(__path__, __name__) # type: ignore diff --git a/pydevd_plugins/extensions/pydevd_plugin_omegaconf.py b/pydevd_plugins/extensions/pydevd_plugin_omegaconf.py new file mode 100644 index 000000000..ea16c6f36 --- /dev/null +++ b/pydevd_plugins/extensions/pydevd_plugin_omegaconf.py @@ -0,0 +1,126 @@ +# based on https://github.com/fabioz/PyDev.Debugger/tree/main/pydevd_plugins/extensions + +import sys +from functools import lru_cache +from typing import Any, Dict + +from _pydevd_bundle.pydevd_extension_api import TypeResolveProvider # type: ignore + +from omegaconf._utils import type_str + + +@lru_cache(maxsize=128) +def find_mod_attr(mod_name: str, attr: str) -> Any: + mod = sys.modules.get(mod_name) + return getattr(mod, attr, None) + + +class Wrapper(object): + def __init__(self, target: Any, desc: str) -> None: + self.target = target + self.desc = desc + + def __repr__(self) -> str: # pragma: no cover + return self.desc + + def __getattr__(self, attr: str) -> Any: # pragma: no cover + return getattr(self.target, attr) + + def __eq__(self, other: Any) -> Any: # pragma: no cover + if isinstance(other, Wrapper): + return self.desc == other.desc and self.target == other.target + else: + return NotImplemented + + +class OmegaConfNodeResolver(object): + def can_provide(self, type_object: Any, type_name: str) -> bool: + Node = find_mod_attr("omegaconf", "Node") + + return Node is not None and issubclass(type_object, (Node, Wrapper)) + + def resolve(self, obj: Any, attribute: str) -> Any: + Node = find_mod_attr("omegaconf", "Node") + DictConfig = find_mod_attr("omegaconf", "DictConfig") + ListConfig = find_mod_attr("omegaconf", "ListConfig") + ValueNode = find_mod_attr("omegaconf", "ValueNode") + + if isinstance(obj, Wrapper): + obj = obj.target + + if attribute == "->" and isinstance(obj, Node): + field = obj._dereference_node(throw_on_resolution_failure=False) + elif isinstance(obj, DictConfig): + field = obj.__dict__["_content"][attribute] + elif isinstance(obj, ListConfig): + field = obj.__dict__["_content"][int(attribute)] + else: # pragma: no cover + assert False + + if isinstance(field, Node) and field._is_interpolation(): + resolved = field._dereference_node(throw_on_resolution_failure=False) + if resolved is not None: + if isinstance(resolved, ValueNode): + resolved_type = type_str(type(resolved._val)) + else: + resolved_type = type_str(type(resolved)) + desc = f"{field} -> {{ {resolved_type} }} {resolved}" + field = Wrapper(field, desc) + + return field + + def get_dictionary(self, obj: Any) -> Dict[str, Any]: + ListConfig = find_mod_attr("omegaconf", "ListConfig") + DictConfig = find_mod_attr("omegaconf", "DictConfig") + Node = find_mod_attr("omegaconf", "Node") + ValueNode = find_mod_attr("omegaconf", "ValueNode") + + if isinstance(obj, Wrapper): + obj = obj.target + + assert isinstance(obj, Node) + + d = {} + + if isinstance(obj, Node): + if obj._is_missing() or obj._is_none(): + return {} + if obj._is_interpolation(): + d["interpolation"] = obj._value() + if obj._parent is not None: + resolved = obj._dereference_node(throw_on_resolution_failure=False) + else: + resolved = None + d["->"] = resolved + return d + else: + if isinstance(obj, ValueNode): + d["_val"] = obj._value() + + if isinstance(obj, ListConfig): + assert not obj._is_interpolation() + assert not obj._is_none() + assert not obj._is_missing() + for idx, node in enumerate(obj.__dict__["_content"]): + d[str(idx)] = node + elif isinstance(obj, DictConfig): + assert not obj._is_interpolation() + assert not obj._is_none() + assert not obj._is_missing() + for key in obj.keys(): + node = obj._get_node(key, throw_on_missing_value=False) + is_inter = node._is_interpolation() + if is_inter: + resolved = node._dereference_node(throw_on_resolution_failure=False) + if resolved is not None: + value = resolved + else: + value = node + else: + value = node._value() + + d[key] = value + return d + + +TypeResolveProvider.register(OmegaConfNodeResolver) diff --git a/requirements/dev.txt b/requirements/dev.txt index d1b6b208b..c18b9b73a 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -14,3 +14,4 @@ pytest-mock sphinx towncrier twine +pydevd \ No newline at end of file diff --git a/tests/test_pydev_resolver_plugin.py b/tests/test_pydev_resolver_plugin.py new file mode 100644 index 000000000..9b6003055 --- /dev/null +++ b/tests/test_pydev_resolver_plugin.py @@ -0,0 +1,257 @@ +import builtins +from typing import Any + +from pytest import fixture, mark, param + +from omegaconf import ( + AnyNode, + BooleanNode, + Container, + DictConfig, + EnumNode, + FloatNode, + IntegerNode, + ListConfig, + Node, + OmegaConf, + StringNode, + ValueNode, +) +from omegaconf._utils import type_str +from pydevd_plugins.extensions.pydevd_plugin_omegaconf import ( + OmegaConfNodeResolver, + Wrapper, +) +from tests import Color + + +@fixture +def resolver() -> Any: + yield OmegaConfNodeResolver() + + +@mark.parametrize( + ("obj", "expected"), + [ + # nodes + param(AnyNode(10), {"_val": 10}, id="any:10"), + param(StringNode("foo"), {"_val": "foo"}, id="str:foo"), + param(IntegerNode(10), {"_val": 10}, id="int:10"), + param(FloatNode(3.14), {"_val": 3.14}, id="float:3.14"), + param(BooleanNode(True), {"_val": True}, id="bool:True"), + param( + EnumNode(enum_type=Color, value=Color.RED), + {"_val": Color.RED}, + id="Color:Color.RED", + ), + param(AnyNode("${foo}"), {"interpolation": "${foo}", "->": None}, id="any:10"), + param( + AnyNode("${foo}", parent=OmegaConf.create({"foo": 10})), + {"interpolation": "${foo}", "->": AnyNode(10)}, + id="any:10", + ), + # DictConfig + param(DictConfig({"a": 10}), {"a": AnyNode(10)}, id="dict"), + param( + DictConfig({"a": 10, "b": "${a}"}), + {"a": AnyNode(10), "b": AnyNode(10)}, + id="dict:interpolation_value", + ), + param( + DictConfig({"a": 10, "b": "${zzz}"}), + {"a": AnyNode(10), "b": AnyNode("${zzz}")}, + id="dict:interpolation_value_error", + ), + param( + DictConfig({"a": 10, "b": "foo_${a}"}), + {"a": AnyNode(10), "b": AnyNode("foo_10")}, + id="dict:str_interpolation_value", + ), + # ListConfig + param( + ListConfig(["a", "b"]), {"0": AnyNode("a"), "1": AnyNode("b")}, id="list" + ), + param( + ListConfig(["${1}", 10]), + {"0": AnyNode("${1}"), "1": AnyNode(10)}, + id="list:interpolation_value", + ), + ], +) +def test_get_dictionary_node(resolver: Any, obj: Any, expected: Any) -> None: + res = resolver.get_dictionary(obj) + assert res == expected + + +@mark.parametrize( + ("obj", "attribute", "expected"), + [ + # dictconfig + param(DictConfig({"a": 10}), "a", AnyNode(10), id="dict"), + param( + DictConfig({"a": DictConfig(None)}), + "a", + DictConfig(None), + id="dict:none", + ), + param( + DictConfig({"a": "${b}", "b": 10}), + "a", + Wrapper(AnyNode("${b}"), desc="${b} -> { int } 10"), + id="dict:value_interpolation", + ), + # listconfig + param(ListConfig([10]), 0, AnyNode(10), id="list"), + param(ListConfig(["???"]), 0, AnyNode("???"), id="list"), + param( + ListConfig(["${.1}", 10]), + 0, + Wrapper(AnyNode("${.1}"), desc="${.1} -> { int } 10"), + id="list", + ), + # wrapper + param( + Wrapper(DictConfig({"a": 10}), desc=".."), + "a", + AnyNode(10), + id="dict_in_wrapper", + ), + # dereference + param( + AnyNode("${a}", parent=DictConfig({"a": 10})), + "->", + AnyNode(10), + id="dereference", + ), + ], +) +def test_resolve( + resolver: Any, + obj: Any, + attribute: str, + expected: Any, +) -> None: + res = resolver.resolve(obj, attribute) + assert res == expected + assert type(res) is type(expected) + + +@mark.parametrize( + ("obj", "attribute", "expected"), + [ + param( + OmegaConf.create({"a": 10, "inter": "${a}"}), + "inter", + {"interpolation": "${a}", "->": AnyNode(10)}, + id="dict:inter", + ), + param( + OmegaConf.create({"missing": "???"}), + "missing", + {}, + id="dict:missing_value", + ), + param( + OmegaConf.create({"none": None}), + "none", + {}, + id="dict:none_value", + ), + param( + OmegaConf.create({"none": DictConfig(None)}), + "none", + {}, + id="dict:none_dictconfig_value", + ), + param( + OmegaConf.create({"missing": DictConfig("???")}), + "missing", + {}, + id="dict:missing_dictconfig_value", + ), + param( + OmegaConf.create({"a": {"b": 10}, "b": DictConfig("${a}")}), + "b", + {"interpolation": "${a}", "->": {"b": 10}}, + id="dict:interpolation_dictconfig_value", + ), + ], +) +def test_get_dictionary_dictconfig( + resolver: Any, + obj: Any, + attribute: str, + expected: Any, +) -> None: + field = resolver.resolve(obj, attribute) + res = resolver.get_dictionary(field) + assert res == expected + assert type(res) is type(expected) + + +@mark.parametrize( + ("obj", "attribute", "expected"), + [ + param( + OmegaConf.create(["${.1}", 10]), + "0", + {"interpolation": "${.1}", "->": AnyNode(10)}, + id="list:inter_value", + ), + param( + OmegaConf.create({"a": ListConfig(None)}), + "a", + {}, + id="list:none_listconfig_value", + ), + param( + OmegaConf.create({"a": ListConfig("???")}), + "a", + {}, + id="list:missing_listconfig_value", + ), + param( + OmegaConf.create({"a": [1, 2], "b": ListConfig("${a}")}), + "b", + {"interpolation": "${a}", "->": [1, 2]}, + id="list:interpolationn_listconfig_value", + ), + ], +) +def test_get_dictionary_listconfig( + resolver: Any, + obj: Any, + attribute: str, + expected: Any, +) -> None: + field = resolver.resolve(obj, attribute) + res = resolver.get_dictionary(field) + assert res == expected + assert type(res) is type(expected) + + +@mark.parametrize( + ("type_", "expected"), + [ + # containers + (Container, True), + (DictConfig, True), + (ListConfig, True), + # nodes + (Node, True), + (ValueNode, True), + (AnyNode, True), + (IntegerNode, True), + (FloatNode, True), + (StringNode, True), + (BooleanNode, True), + # internal wrapper + (Wrapper, True), + # not covering some other things. + (builtins.int, False), + (dict, False), + (list, False), + ], +) +def test_can_provide(resolver: Any, type_: Any, expected: bool) -> None: + assert resolver.can_provide(type_, type_str(type_)) == expected