Merge remote-tracking branch 'origin/main' into faster_to
vmoens committed Oct 9, 2023
2 parents a235463 + a61f7a0 commit ebe4a58
Showing 7 changed files with 89 additions and 74 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
@@ -66,7 +66,7 @@ jobs:
       echo '::group::Lint C source'
       set +e
-      ./.github/unittest/linux/scripts/run-clang-format.py -r torchrl/csrc --clang-format-executable ./clang-format
+      ./.github/unittest/linux/scripts/run-clang-format.py -r tensordict/csrc --clang-format-executable ./clang-format
       if [ $? -ne 0 ]; then
         git --no-pager diff
3 changes: 3 additions & 0 deletions benchmarks/nn/functional_benchmarks_test.py
@@ -177,6 +177,9 @@ def test_vmap_mlp_speed(benchmark, stack, tdmodule):


 @torch.no_grad()
+@pytest.mark.skipif(
+    not torch.cuda.device_count(), reason="cuda device required for test"
+)
 @pytest.mark.parametrize("stack", [True, False])
 @pytest.mark.parametrize("tdmodule", [True, False])
 def test_vmap_transformer_speed(benchmark, stack, tdmodule):
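For context, torch.cuda.device_count() returns 0 on a CPU-only runner, so the new skipif marker removes the benchmark at collection time instead of letting it fail. A minimal standalone sketch of the same guard (the test name is illustrative only):

import pytest
import torch

# Reusable marker: skip whenever no CUDA device is visible.
requires_cuda = pytest.mark.skipif(
    not torch.cuda.device_count(), reason="cuda device required for test"
)


@requires_cuda
def test_needs_gpu():
    assert torch.cuda.is_available()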
8 changes: 6 additions & 2 deletions tensordict/csrc/pybind.cpp
@@ -16,6 +16,10 @@ PYBIND11_MODULE(_tensordict, m) {
   m.def("unravel_keys", &unravel_key, py::arg("key")); // for bc compat
   m.def("unravel_key", &unravel_key, py::arg("key"));
   m.def("_unravel_key_to_tuple", &_unravel_key_to_tuple, py::arg("key"));
-  m.def("unravel_key_list", py::overload_cast<const py::list&>(&unravel_key_list), py::arg("keys"));
-  m.def("unravel_key_list", py::overload_cast<const py::tuple&>(&unravel_key_list), py::arg("keys"));
+  m.def("unravel_key_list",
+        py::overload_cast<const py::list &>(&unravel_key_list),
+        py::arg("keys"));
+  m.def("unravel_key_list",
+        py::overload_cast<const py::tuple &>(&unravel_key_list),
+        py::arg("keys"));
 }
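A minimal sketch of how the two overloads bound above behave when called from Python; the import path is an assumption based on PYBIND11_MODULE(_tensordict, m):

from tensordict._tensordict import unravel_key_list

# The py::list overload takes a Python list of (possibly nested) keys ...
print(unravel_key_list(["a", ("b", ("c", "d"))]))  # ['a', ('b', 'c', 'd')]
# ... and the py::tuple overload converts the tuple and reuses the same code path.
print(unravel_key_list(("a", ("b", "c"))))  # ['a', ('b', 'c')]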
109 changes: 54 additions & 55 deletions tensordict/csrc/utils.h
@@ -8,71 +8,70 @@
This hunk is a clang-format pass over the helpers below: the statements are unchanged and only the formatting moves (for instance `const py::object& key` becomes `const py::object &key`). The formatted code reads:

namespace py = pybind11;

py::tuple _unravel_key_to_tuple(const py::object &key) {
  bool is_tuple = py::isinstance<py::tuple>(key);
  bool is_str = py::isinstance<py::str>(key);

  if (is_tuple) {
    py::list newkey;
    for (const auto &subkey : key) {
      if (py::isinstance<py::str>(subkey)) {
        newkey.append(subkey);
      } else {
        auto _key = _unravel_key_to_tuple(subkey.cast<py::object>());
        if (_key.size() == 0) {
          return py::make_tuple();
        }
        newkey += _key;
      }
    }
    return py::tuple(newkey);
  }
  if (is_str) {
    return py::make_tuple(key);
  } else {
    return py::make_tuple();
  }
}

py::object unravel_key(const py::object &key) {
  bool is_tuple = py::isinstance<py::tuple>(key);
  bool is_str = py::isinstance<py::str>(key);

  if (is_tuple) {
    py::list newkey;
    int count = 0;
    for (const auto &subkey : key) {
      if (py::isinstance<py::str>(subkey)) {
        newkey.append(subkey);
        count++;
      } else {
        auto _key = _unravel_key_to_tuple(subkey.cast<py::object>());
        count += _key.size();
        newkey += _key;
      }
    }
    if (count == 1) {
      return newkey[0];
    }
    return py::tuple(newkey);
  }
  if (is_str) {
    return key;
  } else {
    throw std::runtime_error("key should be a Sequence<NestedKey>");
  }
}

py::list unravel_key_list(const py::list &keys) {
  py::list newkeys;
  for (const auto &key : keys) {
    auto _key = unravel_key(key.cast<py::object>());
    newkeys.append(_key);
  }
  return newkeys;
}

py::list unravel_key_list(const py::tuple &keys) {
  return unravel_key_list(py::list(keys));
}
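Seen from Python, the helpers above give the following behaviour; a hedged sketch, with the import path assumed as for the bindings:

from tensordict._tensordict import _unravel_key_to_tuple, unravel_key

# _unravel_key_to_tuple always returns a tuple: nesting is flattened, a bare
# string becomes a 1-tuple, and an invalid key yields the empty tuple.
assert _unravel_key_to_tuple(("a", ("b", ("c",)))) == ("a", "b", "c")
assert _unravel_key_to_tuple("a") == ("a",)
assert _unravel_key_to_tuple(42) == ()

# unravel_key keeps strings as strings, collapses a single-entry tuple to its
# string, and raises for anything that is not a str or a tuple of nested keys.
assert unravel_key("a") == "a"
assert unravel_key(("a",)) == "a"
assert unravel_key(("a", ("b", "c"))) == ("a", "b", "c")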
14 changes: 7 additions & 7 deletions tensordict/tensorclass.py
@@ -185,10 +185,10 @@ def __torch_function__(
 
     for attr in TensorDict.__dict__.keys():
         func = getattr(TensorDict, attr)
-        if (
-            inspect.ismethod(func) and func.__self__ is TensorDict
-        ):  # detects classmethods
-            setattr(cls, attr, _wrap_classmethod(cls, func))
+        if inspect.ismethod(func):
+            tdcls = func.__self__
+            if issubclass(tdcls, TensorDictBase):  # detects classmethods
+                setattr(cls, attr, _wrap_classmethod(tdcls, cls, func))
 
     cls.to_tensordict = _to_tensordict
     cls.device = property(_device, _device_setter)
@@ -439,10 +439,10 @@ def wrapped_func(*args, **kwargs):
     return wrapped_func
 
 
-def _wrap_classmethod(cls, func):
+def _wrap_classmethod(td_cls, cls, func):
     @functools.wraps(func)
     def wrapped_func(*args, **kwargs):
-        res = func.__get__(cls)(*args, **kwargs)
+        res = func.__get__(td_cls)(*args, **kwargs)
         # res = func(*args, **kwargs)
         if isinstance(res, TensorDictBase):
             # create a new tensorclass from res and copy the metadata from self
@@ -498,7 +498,7 @@ def _setitem(self, item: NestedKey, value: Any) -> None:  # noqa: D417
     if isinstance(item, str) or (
         isinstance(item, tuple) and all(isinstance(_item, str) for _item in item)
     ):
-        raise ValueError("Invalid indexing arguments.")
+        raise ValueError(f"Invalid indexing arguments: {item}.")
 
     if not is_tensorclass(value) and not isinstance(
         value, (TensorDictBase, numbers.Number, Tensor, MemmapTensor)
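The net effect of the _wrap_classmethod change is that TensorDict classmethods exposed on a tensorclass are invoked on the class that owns them (td_cls) and their TensorDict result is cast back to the tensorclass. A hedged sketch mirroring the test added below; the contents of d and the exact import location of the decorator are assumptions:

import torch
from tensordict import TensorDict, TensorDictBase
from tensordict import tensorclass  # location of the decorator may vary by version


@tensorclass
class MyClass:
    a: TensorDictBase


d = {"a": {"b": torch.zeros(10, 3)}}  # example payload, not the test's own `d`
tc = MyClass.from_dict(d)  # from_dict comes from TensorDict and is wrapped above
assert isinstance(tc, MyClass)
assert isinstance(tc.a, TensorDict)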
26 changes: 17 additions & 9 deletions tensordict/tensordict.py
@@ -4733,23 +4733,31 @@ def load_memmap(cls, prefix: str) -> T:
                 key = key[:-1]  # drop "meta.pt" from key
                 metadata = torch.load(path)
                 if key in out.keys(include_nested=True):
-                    out[key].batch_size = metadata["batch_size"]
+                    out.get(key).batch_size = metadata["batch_size"]
                     device = metadata["device"]
                     if device is not None:
-                        out[key] = out[key].to(device)
+                        out.set(key, out.get(key).to(device))
                 else:
-                    out[key] = cls(
-                        {}, batch_size=metadata["batch_size"], device=metadata["device"]
+                    out.set(
+                        key,
+                        cls(
+                            {},
+                            batch_size=metadata["batch_size"],
+                            device=metadata["device"],
+                        ),
                     )
             else:
                 leaf, *_ = key[-1].rsplit(".", 2)  # remove .meta.pt suffix
                 key = (*key[:-1], leaf)
                 metadata = torch.load(path)
-                out[key] = MemmapTensor(
-                    *metadata["shape"],
-                    device=metadata["device"],
-                    dtype=metadata["dtype"],
-                    filename=str(path.parent / f"{leaf}.memmap"),
+                out.set(
+                    key,
+                    MemmapTensor(
+                        *metadata["shape"],
+                        device=metadata["device"],
+                        dtype=metadata["dtype"],
+                        filename=str(path.parent / f"{leaf}.memmap"),
+                    ),
                 )
 
         return out
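For reference, a hedged round-trip sketch for the load_memmap path touched above: memmap_ writes the memory-mapped tensors plus the *.meta.pt metadata files that load_memmap reads back. Exact signatures may differ between versions.

import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.zeros(3, 4), "b": torch.ones(3)}, batch_size=[3])
td.memmap_(prefix="/tmp/td_memmap")  # assumed prefix-based signature
loaded = TensorDict.load_memmap("/tmp/td_memmap")
assert (loaded["b"] == 1).all()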
1 change: 1 addition & 0 deletions test/test_tensorclass.py
@@ -1672,6 +1672,7 @@ class MyClass:
         a: TensorDictBase
 
     tc = MyClass.from_dict(d)
+    assert isinstance(tc, MyClass)
     assert isinstance(tc.a, TensorDict)
     assert tc.batch_size == torch.Size([10])
 
