Skip to content

Commit

Permalink
[Refactor] Refactor context managers
Browse files Browse the repository at this point in the history
ghstack-source-id: 5ea5b1af825a2b5c8b9b1607d3001e34a36a021c
Pull Request resolved: #1098
  • Loading branch information
vmoens committed Nov 21, 2024
1 parent bbfe8c7 commit 55f6b91
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 61 deletions.
2 changes: 1 addition & 1 deletion tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,7 @@ def _quick_set(swap_dict, swap_td):
_quick_set(_swap, swap_dest)
return swap_dest
else:
return TensorDict._new_unsafe(_swap, batch_size=[])
return TensorDict._new_unsafe(_swap, batch_size=torch.Size(()))

def __ne__(self, other: object) -> T | bool:
if is_tensorclass(other):
Expand Down
14 changes: 6 additions & 8 deletions tensordict/nn/probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from tensordict.nn.utils import _set_skip_existing_None
from tensordict.tensorclass import is_non_tensor
from tensordict.tensordict import TensorDictBase
from tensordict.utils import _zip_strict
from tensordict.utils import _ContextManager, _zip_strict
from torch import distributions as D, Tensor

from torch.utils._contextlib import _DecoratorContextManager
Expand Down Expand Up @@ -66,12 +66,12 @@ def from_str(cls, type_str: str) -> InteractionType:
return cls(type_str.lower())


_INTERACTION_TYPE: InteractionType | None = None
_interaction_type = _ContextManager()


def interaction_type() -> InteractionType | None:
"""Returns the current sampling type."""
return _INTERACTION_TYPE
return _interaction_type.get_mode()


class set_interaction_type(_DecoratorContextManager):
Expand All @@ -98,13 +98,11 @@ def clone(self) -> set_interaction_type:
return type(self)(self.type)

def __enter__(self) -> None:
global _INTERACTION_TYPE
self.prev = _INTERACTION_TYPE
_INTERACTION_TYPE = self.type
self.prev = _interaction_type.get_mode()
_interaction_type.set_mode(self.type)

def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
global _INTERACTION_TYPE
_INTERACTION_TYPE = self.prev
_interaction_type.set_mode(self.prev)


class ProbabilisticTensorDictModule(TensorDictModuleBase):
Expand Down
39 changes: 18 additions & 21 deletions tensordict/nn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,24 @@
from typing import Any, Callable

import torch
from tensordict.utils import strtobool
from tensordict.utils import _ContextManager, strtobool
from torch import nn

from torch.utils._contextlib import _DecoratorContextManager

try:
from torch.compiler import is_dynamo_compiling
except ImportError: # torch 2.0
from torch._dynamo import is_compiling as is_dynamo_compiling


DISPATCH_TDNN_MODULES = strtobool(os.environ.get("DISPATCH_TDNN_MODULES", "True"))
_dispatch_tdnn_modules = _ContextManager(
default=strtobool(os.environ.get("DISPATCH_TDNN_MODULES", "True"))
)

__all__ = ["mappings", "inv_softplus", "biased_softplus"]

_SKIP_EXISTING = False

from torch.utils._contextlib import _DecoratorContextManager
_skip_existing = _ContextManager(default=False)


def inv_softplus(bias: float | torch.Tensor) -> float | torch.Tensor:
Expand Down Expand Up @@ -300,19 +302,17 @@ def wrapper(_self, tensordict, *args: Any, **kwargs: Any) -> Any:
def __enter__(self) -> None:
if self.mode and is_dynamo_compiling():
raise RuntimeError("skip_existing is not compatible with TorchDynamo.")
global _SKIP_EXISTING
self.prev = _SKIP_EXISTING
self.prev = _skip_existing.get_mode()
if self.mode is not None:
_SKIP_EXISTING = self.mode
_skip_existing.set_mode(self.mode)
elif not self._called:
raise RuntimeError(
f"It seems you are using {type(self).__name__} as a context manager with ``None`` input. "
f"This behaviour is not allowed."
)

def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
global _SKIP_EXISTING
_SKIP_EXISTING = self.prev
_skip_existing.set_mode(self.prev)


class _set_skip_existing_None(set_skip_existing):
Expand Down Expand Up @@ -353,12 +353,11 @@ def wrapper(_self, tensordict, *args: Any, **kwargs: Any) -> Any:
return tensordict
if is_dynamo_compiling():
return func(_self, tensordict, *args, **kwargs)
global _SKIP_EXISTING
self.prev = _SKIP_EXISTING
self.prev = _skip_existing.get_mode()
try:
result = func(_self, tensordict, *args, **kwargs)
finally:
_SKIP_EXISTING = self.prev
_skip_existing.set_mode(self.prev)
return result

return wrapper
Expand All @@ -375,7 +374,7 @@ def clone(self) -> _set_skip_existing_None:

def skip_existing():
"""Returns whether or not existing entries in a tensordict should be re-computed by a module."""
return _SKIP_EXISTING
return _skip_existing.get_mode()


def _rebuild_buffer(data, requires_grad, backward_hooks):
Expand All @@ -397,7 +396,7 @@ def _rebuild_buffer(data, requires_grad, backward_hooks):

def _dispatch_td_nn_modules():
"""Returns ``True`` if @dispatch should be used. Not using dispatch is faster and also better compatible with torch.compile."""
return DISPATCH_TDNN_MODULES
return _dispatch_tdnn_modules.get_mode()


class _set_dispatch_td_nn_modules(_DecoratorContextManager):
Expand All @@ -411,17 +410,15 @@ def clone(self):
return type(self)(self.mode)

def __enter__(self):
global DISPATCH_TDNN_MODULES
# We want to avoid changing global variables because compile puts guards on them
if DISPATCH_TDNN_MODULES != self.mode:
self._saved_mode = DISPATCH_TDNN_MODULES
DISPATCH_TDNN_MODULES = self.mode
if _dispatch_tdnn_modules.get_mode() != self.mode:
self._saved_mode = _dispatch_tdnn_modules
_dispatch_tdnn_modules.set_mode(self.mode)

def __exit__(self, exc_type, exc_val, exc_tb):
if self._saved_mode is None:
return
global DISPATCH_TDNN_MODULES
DISPATCH_TDNN_MODULES = self._saved_mode
_dispatch_tdnn_modules.set_mode(self._saved_mode)


# Reproduce StrEnum for python<3.11
Expand Down
18 changes: 18 additions & 0 deletions tensordict/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
import re

import sys
import threading
import time
import warnings
from collections import defaultdict
from collections.abc import KeysView
from contextlib import nullcontext
from copy import copy
from functools import wraps
from importlib import import_module
Expand Down Expand Up @@ -2813,3 +2815,19 @@ def _mismatch_keys(keys1, keys2):
if sub2 is not None:
main.append(sub2)
raise KeyError(r" ".join(main))


class _ContextManager:
def __init__(self, default=None):
self._mode: Any | None = default
self._lock = threading.Lock()

def get_mode(self) -> Any | None:
cm = self._lock if not is_dynamo_compiling() else nullcontext()
with cm:
return self._mode

def set_mode(self, type: Any | None) -> None:
cm = self._lock if not is_dynamo_compiling() else nullcontext()
with cm:
self._mode = type
56 changes: 25 additions & 31 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from tensordict._C import unravel_key_list
from tensordict.nn import (
dispatch,
probabilistic as nn_probabilistic,
ProbabilisticTensorDictModule,
ProbabilisticTensorDictSequential,
TensorDictModuleBase,
Expand All @@ -31,7 +30,11 @@
)
from tensordict.nn.distributions.composite import CompositeDistribution
from tensordict.nn.ensemble import EnsembleModule
from tensordict.nn.probabilistic import InteractionType, set_interaction_type
from tensordict.nn.probabilistic import (
interaction_type,
InteractionType,
set_interaction_type,
)
from tensordict.nn.utils import (
_set_dispatch_td_nn_modules,
set_skip_existing,
Expand Down Expand Up @@ -299,10 +302,8 @@ class Data:

@pytest.mark.parametrize("out_keys", [["loc", "scale"], ["loc_1", "scale_1"]])
@pytest.mark.parametrize("lazy", [True, False])
@pytest.mark.parametrize(
"interaction_type", [InteractionType.MODE, InteractionType.RANDOM, None]
)
def test_stateful_probabilistic_deprec(self, lazy, interaction_type, out_keys):
@pytest.mark.parametrize("it", [InteractionType.MODE, InteractionType.RANDOM, None])
def test_stateful_probabilistic_deprec(self, lazy, it, out_keys):
torch.manual_seed(0)
param_multiplier = 2
if lazy:
Expand Down Expand Up @@ -332,10 +333,10 @@ def test_stateful_probabilistic_deprec(self, lazy, interaction_type, out_keys):
tensordict_module = ProbabilisticTensorDictSequential(net, prob_module)

td = TensorDict({"in": torch.randn(3, 3)}, [3])
with set_interaction_type(interaction_type):
with set_interaction_type(it):
with (
pytest.warns(UserWarning, match="deterministic_sample")
if interaction_type in (InteractionType.DETERMINISTIC, None)
if it in (InteractionType.DETERMINISTIC, None)
else contextlib.nullcontext()
):
tensordict_module(td)
Expand All @@ -345,12 +346,8 @@ def test_stateful_probabilistic_deprec(self, lazy, interaction_type, out_keys):
@pytest.mark.parametrize("out_keys", [["low"], ["low1"], [("stuff", "low1")]])
@pytest.mark.parametrize("lazy", [True, False])
@pytest.mark.parametrize("max_dist", [1.0, 2.0])
@pytest.mark.parametrize(
"interaction_type", [InteractionType.MODE, InteractionType.RANDOM, None]
)
def test_stateful_probabilistic_kwargs(
self, lazy, interaction_type, out_keys, max_dist
):
@pytest.mark.parametrize("it", [InteractionType.MODE, InteractionType.RANDOM, None])
def test_stateful_probabilistic_kwargs(self, lazy, it, out_keys, max_dist):
torch.manual_seed(0)
if lazy:
net = nn.LazyLinear(4)
Expand All @@ -376,10 +373,10 @@ def test_stateful_probabilistic_kwargs(
tensordict_module = ProbabilisticTensorDictSequential(net, prob_module)

td = TensorDict({"in": torch.randn(3, 3)}, [3])
with set_interaction_type(interaction_type):
with set_interaction_type(it):
with (
pytest.warns(UserWarning, match="deterministic_sample")
if interaction_type in (None, InteractionType.DETERMINISTIC)
if it in (None, InteractionType.DETERMINISTIC)
else contextlib.nullcontext()
):
tensordict_module(td)
Expand Down Expand Up @@ -409,10 +406,8 @@ def test_nontensor(self):
],
)
@pytest.mark.parametrize("lazy", [True, False])
@pytest.mark.parametrize(
"interaction_type", [InteractionType.MODE, InteractionType.RANDOM, None]
)
def test_stateful_probabilistic(self, lazy, interaction_type, out_keys):
@pytest.mark.parametrize("it", [InteractionType.MODE, InteractionType.RANDOM, None])
def test_stateful_probabilistic(self, lazy, it, out_keys):
torch.manual_seed(0)
param_multiplier = 2
if lazy:
Expand Down Expand Up @@ -441,10 +436,10 @@ def test_stateful_probabilistic(self, lazy, interaction_type, out_keys):
)

td = TensorDict({"in": torch.randn(3, 3)}, [3])
with set_interaction_type(interaction_type):
with set_interaction_type(it):
with (
pytest.warns(UserWarning, match="deterministic_sample")
if interaction_type in (None, InteractionType.DETERMINISTIC)
if it in (None, InteractionType.DETERMINISTIC)
else contextlib.nullcontext()
):
tensordict_module(td)
Expand Down Expand Up @@ -1043,18 +1038,16 @@ def test_subsequence_weight_update(self):
assert torch.allclose(td_module[0].module.weight, sub_seq_1[0].module.weight)


@pytest.mark.parametrize(
"interaction_type", [InteractionType.RANDOM, InteractionType.MODE]
)
@pytest.mark.parametrize("it", [InteractionType.RANDOM, InteractionType.MODE])
class TestSIM:
def test_cm(self, interaction_type):
with set_interaction_type(interaction_type):
assert nn_probabilistic._INTERACTION_TYPE == interaction_type
def test_cm(self, it):
with set_interaction_type(it):
assert interaction_type() == it

def test_dec(self, interaction_type):
@set_interaction_type(interaction_type)
def test_dec(self, it):
@set_interaction_type(it)
def dummy():
assert nn_probabilistic._INTERACTION_TYPE == interaction_type
assert interaction_type() == it

dummy()

Expand Down Expand Up @@ -1950,6 +1943,7 @@ class MyModule(nn.Module):
params_m0 = params_m
params_m0 = params_m0.apply(lambda x: x.data * 0)
assert (params_m0 == 0).all()
assert not (params_m == 0).all()
with params_m0.to_module(m):
assert (params_m == 0).all()
assert not (params_m == 0).all()
Expand Down

0 comments on commit 55f6b91

Please sign in to comment.