Skip to content

Commit

Permalink
[BugFix] smarter check in set_interaction_type
Browse files Browse the repository at this point in the history
ghstack-source-id: 1821309ad24827c22c40c41f3544e7a768325f72
Pull Request resolved: #1088

(cherry picked from commit db2b5e6)
  • Loading branch information
vmoens committed Nov 14, 2024
1 parent b26bbe3 commit 0b3c778
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
7 changes: 5 additions & 2 deletions tensordict/nn/probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,11 @@ def __init__(
self, type: InteractionType | str | None = InteractionType.DETERMINISTIC
) -> None:
super().__init__()
if isinstance(type, str):
type = InteractionType(type.lower())
if not isinstance(type, InteractionType) and type is not None:
if isinstance(type, str):
type = InteractionType(type.lower())
else:
raise ValueError(f"{type} is not a valid InteractionType")
self.type = type

def clone(self) -> set_interaction_type:
Expand Down
3 changes: 1 addition & 2 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,8 @@ def test_from_str_correct_conversion(self, str_and_expected_type):

@pytest.mark.parametrize("unsupported_type_str", ["foo"])
def test_from_str_correct_raise(self, unsupported_type_str):
with pytest.raises(ValueError) as err:
with pytest.raises(ValueError, match=" is not a valid InteractionType"):
InteractionType.from_str(unsupported_type_str)
assert unsupported_type_str in str(err) and "is unsupported" in str(err)


class TestTDModule:
Expand Down

0 comments on commit 0b3c778

Please sign in to comment.