Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 23, 2024
1 parent c01404c commit 6818c8c
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 9 deletions.
13 changes: 7 additions & 6 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
_CloudpickleWrapper,
_DTYPE2STRDTYPE,
_GENERIC_NESTED_ERR,
_is_dataclass as is_dataclass,
_is_non_tensor,
_is_number,
_is_tensorclass,
Expand Down Expand Up @@ -9452,6 +9453,7 @@ def _validate_value(
if device is not None and value.device != device:
if _device_recorder.marked and device.type != "cuda":
_device_recorder.record_transfer(device)
assert not non_blocking
value = value.to(device, non_blocking=non_blocking)
if check_shape:
if is_tc is None:
Expand Down Expand Up @@ -9874,16 +9876,15 @@ def from_any(cls, obj, *, auto_batch_size: bool = False):
return cls.from_dict(obj, auto_batch_size=auto_batch_size)
if isinstance(obj, np.ndarray) and hasattr(obj.dtype, "names"):
return cls.from_struct_array(obj, auto_batch_size=auto_batch_size)
from dataclasses import is_dataclass

if is_dataclass(obj):
return cls.from_dataclass(obj, auto_batch_size=auto_batch_size)
if is_namedtuple(obj):
return cls.from_namedtuple(obj, auto_batch_size=auto_batch_size)
if isinstance(obj, tuple):
return cls.from_tuple(obj, auto_batch_size=auto_batch_size)
if isinstance(obj, list):
return cls.from_tuple(tuple(obj), auto_batch_size=auto_batch_size)
if is_dataclass(obj):
return cls.from_dataclass(obj, auto_batch_size=auto_batch_size)
if is_namedtuple(obj):
return cls.from_namedtuple(obj, auto_batch_size=auto_batch_size)
if _has_h5:
import h5py

Expand Down Expand Up @@ -9942,7 +9943,7 @@ def from_dataclass(
from tensordict.tensorclass import from_dataclass

return from_dataclass(dataclass, auto_batch_size=auto_batch_size)
from dataclasses import fields, is_dataclass
from dataclasses import fields

from tensordict import TensorDict

Expand Down
3 changes: 2 additions & 1 deletion tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
CompatibleType,
)
from tensordict.utils import ( # @manual=//pytorch/tensordict:_C
_is_dataclass as is_dataclass,
_is_json_serializable,
_is_tensorclass,
_LOCK_ERROR,
Expand Down Expand Up @@ -450,7 +451,7 @@ def from_dataclass(
by default, this method will return a tensorclass instance or type.
"""
from dataclasses import asdict, is_dataclass, make_dataclass
from dataclasses import asdict, make_dataclass

if isinstance(obj, type):
if is_tensorclass(obj):
Expand Down
11 changes: 11 additions & 0 deletions tensordict/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from collections import defaultdict
from collections.abc import KeysView
from copy import copy
from dataclasses import _FIELDS, GenericAlias
from functools import wraps
from importlib import import_module
from numbers import Number
Expand Down Expand Up @@ -2813,3 +2814,13 @@ def _mismatch_keys(keys1, keys2):
if sub2 is not None:
main.append(sub2)
raise KeyError(r" ".join(main))


def _is_dataclass(obj):
"""Like dataclasses.is_dataclass but compatible with compile."""
cls = (
obj
if isinstance(obj, type) and not isinstance(obj, GenericAlias)
else type(obj)
)
return hasattr(cls, _FIELDS)
5 changes: 3 additions & 2 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -10661,7 +10661,8 @@ def test_non_tensor_call(self):

def test_nontensor_dict(self, non_tensor_data):
assert (
TensorDict.from_dict(non_tensor_data.to_dict()) == non_tensor_data
TensorDict.from_dict(non_tensor_data.to_dict(), auto_batch_size=True)
== non_tensor_data
).all()

def test_nontensor_tensor(self):
Expand Down Expand Up @@ -11202,7 +11203,7 @@ def _to_float(td, td_name, tmpdir):
td._source = td._source.float()
elif td_name in ("td_h5",):
td = PersistentTensorDict.from_dict(
td.float().to_dict(), filename=tmpdir + "/file.t"
td.float().to_dict(), filename=tmpdir + "/file.t", auto_batch_size=True
)
elif td_name in ("td_params",):
td = TensorDictParams(td.data.float())
Expand Down

0 comments on commit 6818c8c

Please sign in to comment.