Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make is_primitive_list more precise #944

Merged
merged 2 commits into from
May 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions omegaconf/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,9 +605,7 @@ def is_int(st: str) -> bool:


def is_primitive_list(obj: Any) -> bool:
from .base import Container

return not isinstance(obj, Container) and isinstance(obj, (list, tuple))
return isinstance(obj, (list, tuple))


def is_primitive_dict(obj: Any) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion omegaconf/omegaconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -993,7 +993,7 @@ def _node_wrap(
element_type=element_type,
)
elif (is_list_annotation(ref_type) or is_tuple_annotation(ref_type)) or (
is_primitive_list(value) and ref_type is Any
type(value) in (list, tuple) and ref_type is Any
):
element_type = get_list_element_type(ref_type)
node = ListConfig(
Expand Down
26 changes: 25 additions & 1 deletion tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union

import attr
from pytest import warns
Expand Down Expand Up @@ -226,6 +226,11 @@ class SubscriptedListOpt:
list_opt: List[Optional[int]] = field(default_factory=lambda: [1, 2, None])


@dataclass
class ListOfAny:
list: List[Any]


@dataclass
class UntypedDict:
dict: Dict = field(default_factory=lambda: {"foo": "var"}) # type: ignore
Expand All @@ -250,6 +255,11 @@ class SubscriptedDictOpt:
)


@dataclass
class DictOfAny:
dict: Dict[Any, Any]


@dataclass
class InterpolationList:
list: List[float] = II("optimization.lr")
Expand All @@ -265,6 +275,20 @@ class Str2Int(Dict[str, int]):
pass


class DictSubclass(Dict[Any, Any]):
pass


class ListSubclass(List[Any]):
pass


class Shape(NamedTuple):
channels: int
height: int
width: int


@dataclass
class OptTuple:
x: Optional[Tuple[int, ...]] = None
Expand Down
57 changes: 55 additions & 2 deletions tests/test_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import platform
import re
import sys
from collections.abc import Sequence
from pathlib import Path
from textwrap import dedent
from typing import Any, Dict, List, Optional
Expand All @@ -10,8 +11,18 @@
from pytest import mark, param, raises

from omegaconf import DictConfig, ListConfig, OmegaConf
from omegaconf.errors import UnsupportedValueType
from tests import ConcretePlugin, IllegalType, NonCopyableIllegalType, Plugin
from omegaconf.errors import UnsupportedValueType, ValidationError
from tests import (
ConcretePlugin,
DictOfAny,
DictSubclass,
IllegalType,
ListOfAny,
ListSubclass,
NonCopyableIllegalType,
Plugin,
Shape,
)


@mark.parametrize(
Expand Down Expand Up @@ -112,6 +123,48 @@ def test_create_allow_objects_non_copyable(input_: Any) -> None:
assert cfg == input_


@mark.parametrize(
"input_",
[
param(Shape(10, 2, 3), id="shape"),
param(ListSubclass((1, 2, 3)), id="list_subclass"),
param(DictSubclass({"key": "value"}), id="dict_subclass"),
],
)
class TestCreationWithCustomClass:
def test_top_level(self, input_: Any) -> None:
if isinstance(input_, Sequence):
cfg = OmegaConf.create(input_) # type: ignore
assert isinstance(cfg, ListConfig)
else:
with raises(ValidationError):
OmegaConf.create(input_)

def test_nested(self, input_: Any) -> None:
with raises(UnsupportedValueType):
OmegaConf.create({"foo": input_})

def test_nested_allow_objects(self, input_: Any) -> None:
cfg = OmegaConf.create({"foo": input_}, flags={"allow_objects": True})
assert isinstance(cfg.foo, type(input_))

def test_structured_conf(self, input_: Any) -> None:
if isinstance(input_, Sequence):
cfg = OmegaConf.structured(ListOfAny(input_)) # type: ignore
assert isinstance(cfg.list, ListConfig)
else:
cfg = OmegaConf.structured(DictOfAny(input_))
assert isinstance(cfg.dict, DictConfig)

def test_direct_creation_of_listconfig_or_dictconfig(self, input_: Any) -> None:
if isinstance(input_, Sequence):
cfg = ListConfig(input_) # type: ignore
assert isinstance(cfg, ListConfig)
else:
cfg = DictConfig(input_) # type: ignore
assert isinstance(cfg, DictConfig)


@mark.parametrize(
"input_",
[
Expand Down
53 changes: 53 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
get_list_element_type,
is_dict_annotation,
is_list_annotation,
is_primitive_dict,
is_primitive_list,
is_supported_union_annotation,
is_tuple_annotation,
is_union_annotation,
Expand All @@ -42,8 +44,11 @@
Color,
ConcretePlugin,
Dataframe,
DictSubclass,
IllegalType,
ListSubclass,
Plugin,
Shape,
Str2Int,
UnionAnnotations,
User,
Expand Down Expand Up @@ -662,6 +667,54 @@ def test_type_str_nonetype(type_: Any, expected: str) -> None:
assert _utils.type_str(type_) == expected


@mark.parametrize(
"obj, expected",
[
param([], True, id="list"),
param([1], True, id="list1"),
param((), True, id="tuple"),
param((1,), True, id="tuple1"),
param({}, False, id="dict"),
param(ListSubclass(), True, id="list_subclass"),
param(Shape(10, 2, 3), True, id="namedtuple"),
],
)
def test_is_primitive_list(obj: Any, expected: bool) -> None:
assert is_primitive_list(obj) == expected


@mark.parametrize(
"obj, expected",
[
param({}, True, id="dict"),
param({1: 2}, True, id="dict1"),
param([], False, id="list"),
param((), False, id="tuple"),
],
)
def test_is_primitive_dict(obj: Any, expected: bool) -> None:
assert is_primitive_dict(obj) == expected


@mark.parametrize(
"obj",
[
param(DictConfig({}), id="dictconfig"),
param(ListConfig([]), id="listconfig"),
param(DictSubclass(), id="dict_subclass"),
param(Str2Int(), id="dict_subclass_dataclass"),
param(User, id="user"),
param(User("bond", 7), id="user"),
],
)
class TestIsPrimitiveContainerNegative:
def test_is_primitive_list(self, obj: Any) -> None:
assert not is_primitive_list(obj)

def test_is_primitive_dict(self, obj: Any) -> None:
assert not is_primitive_dict(obj)


@mark.parametrize(
"type_, expected",
[
Expand Down