Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 4, 2024
1 parent 06469d0 commit f563108
Showing 1 changed file with 20 additions and 6 deletions.
26 changes: 20 additions & 6 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import argparse
import contextlib
import copy
import os
import pickle
import unittest
import weakref
Expand Down Expand Up @@ -65,6 +66,7 @@
except ImportError:
from tensordict.utils import Buffer


# Capture all warnings
pytestmark = [
pytest.mark.filterwarnings("error"),
Expand All @@ -80,6 +82,18 @@
),
]

PYTORCH_TEST_FBCODE = os.getenv("PYTORCH_TEST_FBCODE")
if PYTORCH_TEST_FBCODE:
pytestmark.append(
pytest.mark.filterwarnings("ignore:aggregate_probabilities"),
)
pytestmark.append(
pytest.mark.filterwarnings("ignore:include_sum"),
)
pytestmark.append(
pytest.mark.filterwarnings("ignore:inplace"),
)


class TestInteractionType:
@pytest.mark.parametrize(
Expand Down Expand Up @@ -1091,17 +1105,17 @@ def test_probtdseq_multdist(self, include_sum, aggregate_probabilities, inplace)

v = tdm(TensorDict(x=torch.randn(10, 3)))
assert set(v.keys()) == {"x", "loc", "y", "loc2", "z"}
if aggregate_probabilities is None:
if aggregate_probabilities is None and not PYTORCH_TEST_FBCODE:
cm0 = pytest.warns(
expected_warning=DeprecationWarning, match="aggregate_probabilities"
)
else:
cm0 = contextlib.nullcontext()
if include_sum is None:
if include_sum is None and not PYTORCH_TEST_FBCODE:
cm1 = pytest.warns(expected_warning=DeprecationWarning, match="include_sum")
else:
cm1 = contextlib.nullcontext()
if inplace is None:
if inplace is None and not PYTORCH_TEST_FBCODE:
cm2 = pytest.warns(expected_warning=DeprecationWarning, match="inplace")
else:
cm2 = contextlib.nullcontext()
Expand Down Expand Up @@ -1150,17 +1164,17 @@ def test_probtdseq_intermediate_dist(

v = tdm(TensorDict(x=torch.randn(10, 3)))
assert set(v.keys()) == {"x", "loc", "y", "loc2"}
if aggregate_probabilities is None:
if aggregate_probabilities is None and not PYTORCH_TEST_FBCODE:
cm0 = pytest.warns(
expected_warning=DeprecationWarning, match="aggregate_probabilities"
)
else:
cm0 = contextlib.nullcontext()
if include_sum is None:
if include_sum is None and not PYTORCH_TEST_FBCODE:
cm1 = pytest.warns(expected_warning=DeprecationWarning, match="include_sum")
else:
cm1 = contextlib.nullcontext()
if inplace is None:
if inplace is None and not PYTORCH_TEST_FBCODE:
cm2 = pytest.warns(expected_warning=DeprecationWarning, match="inplace")
else:
cm2 = contextlib.nullcontext()
Expand Down

0 comments on commit f563108

Please sign in to comment.