Skip to content

Commit

Permalink
added positional only arg test
Browse files Browse the repository at this point in the history
  • Loading branch information
omry committed Mar 12, 2021
1 parent 0ff7fc0 commit f9950c4
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 16 deletions.
10 changes: 3 additions & 7 deletions hydra/_internal/instantiate/_instantiate2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import copy
import sys
from enum import Enum
from typing import Any, Callable, Tuple, Union
from typing import Any, Callable, Sequence, Tuple, Union

from omegaconf import ListConfig, OmegaConf, SCMode
from omegaconf import OmegaConf, SCMode
from omegaconf._utils import is_structured_config

from hydra._internal.utils import _locate
Expand Down Expand Up @@ -34,11 +34,7 @@ def _extract_pos_args(*input_args: Any, **kwargs: Any) -> Tuple[Any, Any]:
config_args = kwargs.pop(_Keys.ARGS, ())
output_args = config_args

if (
isinstance(config_args, tuple)
or isinstance(config_args, list)
or isinstance(config_args, ListConfig)
):
if isinstance(config_args, Sequence):
if len(input_args) > 0:
output_args = input_args
else:
Expand Down
2 changes: 1 addition & 1 deletion tests/instantiate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
self.kwargs = kwargs

def __repr__(self) -> str:
return f"{self.args=},{self.kwargs=}"
return f"self.args={self.args},self.kwarg={self.kwargs}"

def __eq__(self, other: Any) -> Any:
if isinstance(other, ArgsClass):
Expand Down
40 changes: 40 additions & 0 deletions tests/instantiate/positional_only.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

# Contains Python 3.8 syntax
# flake does not like it when running on older versions of Python
# flake8: noqa
# With mypy, there does not seem to be a way to prevent an error when running in Python < 3.8.
# For this reason, I am including the code as a string (the horror).
# Once we upgrade to mypy 0.812, we should be able to use --exclude and eliminate this hack.
from typing import Any

code = """
class PosOnlyArgsClass:
def __init__(self, a: Any, b: Any, /, **kwargs: Any) -> None:
assert isinstance(kwargs, dict)
self.a = a
self.b = b
self.kwargs = kwargs
def __repr__(self) -> str:
return f"{self.a=},{self.b},{self.kwargs=}"
def __eq__(self, other: Any) -> Any:
if isinstance(other, PosOnlyArgsClass):
return (
self.a == other.a and self.b == other.b and self.kwargs == other.kwargs
)
else:
return NotImplemented
"""


# Dummy class to keep mypy happy
class PosOnlyArgsClass:
def __init__(self, *args: Any, **kwargs: Any) -> None:
...


exec(code) # nosec
18 changes: 10 additions & 8 deletions tests/instantiate/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,22 @@
@mark.parametrize(
"name,expected",
[
("tests.Adam", Adam),
("tests.Parameters", Parameters),
("tests.AClass", AClass),
("tests.ASubclass", ASubclass),
("tests.NestingClass", NestingClass),
("tests.AnotherClass", AnotherClass),
("tests.instantiate.Adam", Adam),
("tests.instantiate.Parameters", Parameters),
("tests.instantiate.AClass", AClass),
("tests.instantiate.ASubclass", ASubclass),
("tests.instantiate.NestingClass", NestingClass),
("tests.instantiate.AnotherClass", AnotherClass),
("", raises(ImportError, match=re.escape("Empty path"))),
[
"not_found",
raises(ImportError, match=re.escape("Error loading module 'not_found'")),
],
(
"tests.b.c.Door",
raises(ImportError, match=re.escape("No module named 'tests.b'")),
"tests.instantiate.b.c.Door",
raises(
ImportError, match=re.escape("No module named 'tests.instantiate.b'")
),
),
],
)
Expand Down
51 changes: 51 additions & 0 deletions tests/instantiate/test_positional_only_arguments.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import sys
from typing import Any

from pytest import mark, param, skip

from hydra.utils import instantiate

if sys.version_info < (3, 8):
skip(
msg="Positional-only syntax is only supported in Python 3.8 or newer",
allow_module_level=True,
)


from .positional_only import PosOnlyArgsClass


@mark.parametrize(
("cfg", "args", "expected"),
[
param(
{
"_target_": "tests.instantiate.positional_only.PosOnlyArgsClass",
"_args_": [1, 2],
},
[],
PosOnlyArgsClass(1, 2),
id="pos_only_in_config",
),
param(
{
"_target_": "tests.instantiate.positional_only.PosOnlyArgsClass",
},
[1, 2],
PosOnlyArgsClass(1, 2),
id="pos_only_in_override",
),
param(
{
"_target_": "tests.instantiate.positional_only.PosOnlyArgsClass",
"_args_": [1, 2],
},
[3, 4],
PosOnlyArgsClass(3, 4),
id="pos_only_in_both",
),
],
)
def test_positional_only_arguments(cfg: Any, args: Any, expected: Any) -> None:
assert instantiate(cfg, *args) == expected

0 comments on commit f9950c4

Please sign in to comment.