From db2b5e656e46add2bc684c16090b4e97540a6c37 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 14 Nov 2024 06:34:09 +0000 Subject: [PATCH] [BugFix] smarter check in set_interaction_type ghstack-source-id: 1821309ad24827c22c40c41f3544e7a768325f72 Pull Request resolved: https://github.com/pytorch/tensordict/pull/1088 --- tensordict/nn/probabilistic.py | 7 +++++-- test/test_nn.py | 3 +-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/tensordict/nn/probabilistic.py b/tensordict/nn/probabilistic.py index 497394a99..72deae8cc 100644 --- a/tensordict/nn/probabilistic.py +++ b/tensordict/nn/probabilistic.py @@ -79,8 +79,11 @@ def __init__( self, type: InteractionType | str | None = InteractionType.DETERMINISTIC ) -> None: super().__init__() - if isinstance(type, str): - type = InteractionType(type.lower()) + if not isinstance(type, InteractionType) and type is not None: + if isinstance(type, str): + type = InteractionType(type.lower()) + else: + raise ValueError(f"{type} is not a valid InteractionType") self.type = type def clone(self) -> set_interaction_type: diff --git a/test/test_nn.py b/test/test_nn.py index e27cb9b9d..274d08b29 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -88,9 +88,8 @@ def test_from_str_correct_conversion(self, str_and_expected_type): @pytest.mark.parametrize("unsupported_type_str", ["foo"]) def test_from_str_correct_raise(self, unsupported_type_str): - with pytest.raises(ValueError) as err: + with pytest.raises(ValueError, match=" is not a valid InteractionType"): InteractionType.from_str(unsupported_type_str) - assert unsupported_type_str in str(err) and "is unsupported" in str(err) class TestTDModule: