[Feature] Seed workers in TensorDict.map #562

Merged: 43 commits (seeding-pool into main) on Nov 24, 2023

Commits (43):
6bd248d  init (vmoens, Nov 21, 2023)
237dddf  init (vmoens, Nov 21, 2023)
2216e4e  amend (vmoens, Nov 21, 2023)
b607c0e  edit docstring (vmoens, Nov 21, 2023)
aa42433  maxtasksperchild (vmoens, Nov 22, 2023)
152612c  fix (vmoens, Nov 22, 2023)
8ca93b6  TestMap.fn (vmoens, Nov 22, 2023)
7c63518  timeout in get (vmoens, Nov 22, 2023)
01b9ee0  processes=num_workers (vmoens, Nov 22, 2023)
9ca19c7  amend (vmoens, Nov 22, 2023)
80ee04f  amend (vmoens, Nov 22, 2023)
2a4658e  amend (vmoens, Nov 22, 2023)
9bf4f48  amend (vmoens, Nov 22, 2023)
82b668e  amend (vmoens, Nov 22, 2023)
47f4bf7  amend (vmoens, Nov 22, 2023)
c8006d2  amend (vmoens, Nov 22, 2023)
e42f403  amend (vmoens, Nov 22, 2023)
988feda  amend (vmoens, Nov 22, 2023)
f6fdbc6  amend (vmoens, Nov 22, 2023)
b48db6f  amend (vmoens, Nov 22, 2023)
11ff0a0  amend (vmoens, Nov 22, 2023)
1c70cc4  amend (vmoens, Nov 22, 2023)
8723bde  amend (vmoens, Nov 22, 2023)
28921aa  amend (vmoens, Nov 22, 2023)
7548cd2  amend (vmoens, Nov 22, 2023)
bb168e2  amend (vmoens, Nov 22, 2023)
975f85b  amend (vmoens, Nov 22, 2023)
14666e6  amend (vmoens, Nov 22, 2023)
2b490ff  amend (vmoens, Nov 22, 2023)
70cc6ee  amend (vmoens, Nov 22, 2023)
4f5737e  amend (vmoens, Nov 22, 2023)
838ed0e  amend (vmoens, Nov 22, 2023)
3c4d96a  amend (vmoens, Nov 22, 2023)
0d336e1  amend (vmoens, Nov 22, 2023)
b92cea6  amend (vmoens, Nov 22, 2023)
4435f38  amend (vmoens, Nov 22, 2023)
1a2e65e  amend (vmoens, Nov 22, 2023)
4edd592  amend (vmoens, Nov 22, 2023)
5223dd1  amend (vmoens, Nov 23, 2023)
3272e80  Merge remote-tracking branch 'origin/main' into seeding-pool (vmoens, Nov 23, 2023)
2fa557f  amend (vmoens, Nov 23, 2023)
401f055  amend (vmoens, Nov 23, 2023)
a70e95e  fix (vmoens, Nov 24, 2023)
tensordict/base.py (52 additions, 9 deletions)
@@ -33,6 +33,7 @@
     _GENERIC_NESTED_ERR,
     _is_tensorclass,
     _KEY_ERROR,
+    _proc_init,
     _shape,
     _split_tensordict,
     _td_fields,
@@ -2910,10 +2911,12 @@ def map(
         self,
         fn: Callable,
         dim: int = 0,
-        num_workers: int = None,
-        chunksize: int = None,
-        num_chunks: int = None,
-        pool: mp.Pool = None,
+        num_workers: int | None = None,
+        chunksize: int | None = None,
+        num_chunks: int | None = None,
+        pool: mp.Pool | None = None,
+        generator: torch.Generator | None = None,
+        max_tasks_per_child: int | None = None,
     ):
         """Maps a function to splits of the tensordict across one dimension.

@@ -2938,16 +2941,42 @@
             of workers. For very large tensordicts, such large chunks
             may not fit in memory for the operation to be done and
             more chunks may be needed to make the operation practically
-            doable. This argument is exclusive with num_chunks.
+            doable. This argument is exclusive with ``num_chunks``.
         num_chunks (int, optional): the number of chunks to split the tensordict
             into. If none is provided, the number of chunks will equate the number
             of workers. For very large tensordicts, such large chunks
             may not fit in memory for the operation to be done and
             more chunks may be needed to make the operation practically
-            doable. This argument is exclusive with chunksize.
+            doable. This argument is exclusive with ``chunksize``.
         pool (mp.Pool, optional): a multiprocess Pool instance to use
             to execute the job. If none is provided, a pool will be created
             within the ``map`` method.
+        generator (torch.Generator, optional): a generator to use for seeding.
+            A base seed will be generated from it, and each worker
+            of the pool will be seeded with that base seed incremented
+            by a unique integer from ``0`` to ``num_workers - 1``. If no
+            generator is provided, a random integer will be used as the seed.
+            To work with unseeded workers, a pool should be created separately
+            and passed to :meth:`map` directly.
+
+            .. note::
+                Caution should be taken when providing a low-valued seed, as
+                this can cause autocorrelation between experiments. For example,
+                if 8 workers are requested and the seed is 4, the worker seeds
+                will range from 4 to 11; if the seed is 5, they will range
+                from 5 to 12. These two experiments share 7 of their 8 seeds,
+                which can have unexpected effects on the results.
+
+            .. note::
+                The goal of seeding the workers is to give each worker an
+                independent seed, NOT to make results reproducible across calls
+                of the `map` method. In other words, two experiments may and
+                probably will return different results, as it is impossible to
+                know which worker will pick which job. However, we can make sure
+                that each worker has a different seed and that the pseudo-random
+                operations on each will be uncorrelated.
+        max_tasks_per_child (int, optional): the maximum number of jobs picked
+            by every child process. Defaults to ``None``, i.e., no restriction
+            on the number of jobs.

         Examples:
             >>> import torch

@@ -2976,7 +3005,21 @@
         if pool is None:
             if num_workers is None:
                 num_workers = mp.cpu_count()  # Get the number of CPU cores
-            with mp.Pool(num_workers) as pool:
+            if generator is None:
+                generator = torch.Generator()
+            seed = (
+                torch.empty((), dtype=torch.int64).random_(generator=generator).item()
+            )
+
+            queue = mp.Queue(maxsize=num_workers)
+            for i in range(num_workers):
+                queue.put(i)
+            with mp.Pool(
+                processes=num_workers,
+                initializer=_proc_init,
+                initargs=(seed, queue),
+                maxtasksperchild=max_tasks_per_child,
+            ) as pool:
                 return self.map(
                     fn, dim=dim, chunksize=chunksize, num_chunks=num_chunks, pool=pool
                 )
@@ -2989,8 +3032,8 @@

         self_split = _split_tensordict(self, chunksize, num_chunks, num_workers, dim)
         chunksize = 1
-        out = pool.imap(fn, self_split, chunksize)
-        out = torch.cat(list(out), dim)
+        imap = pool.imap(fn, self_split, chunksize)
+        out = torch.cat(list(imap), dim)
         return out

     # Functorch compatibility
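For orientation, here is a minimal sketch of how the new `generator` argument might be exercised from user code. The `add_noise` function and the toy tensordict are illustrative assumptions, not part of this diff:

```python
import torch
from tensordict import TensorDict

def add_noise(td):
    # runs in a worker process; torch is already seeded by _proc_init
    # with base_seed + worker_id, so draws are uncorrelated across workers
    td["noise"] = torch.randn(td.shape)
    return td

if __name__ == "__main__":
    td = TensorDict({"x": torch.zeros(100)}, batch_size=[100])
    gen = torch.Generator()
    gen.manual_seed(0)
    # a base seed is drawn from `gen`, and each of the 4 workers is seeded
    # with base_seed + worker_id; which worker picks which chunk stays
    # nondeterministic
    out = td.map(add_noise, num_workers=4, generator=gen, chunksize=1)
```

As the docstring stresses, this seeds the workers' random streams but does not make `map` reproducible across calls, since job-to-worker assignment is not controlled.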
tensordict/memmap.py (1 addition, 1 deletion)
@@ -537,7 +537,7 @@ def __getitem__(self, item):
                     "isn't supported at the moment."
                 ) from err
             raise
-        if out.data_ptr() == self.data_ptr():
+        if out.storage().data_ptr() == self.storage().data_ptr():
             out = MemoryMappedTensor(out)
             out._handler = self._handler
             out._filename = self._filename
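The memmap.py fix matters because `Tensor.data_ptr()` points at a view's first element, so it differs from the base tensor's pointer whenever indexing produces an offset view, while `storage().data_ptr()` identifies the shared buffer. A quick plain-torch illustration (not part of the diff):

```python
import torch

base = torch.arange(10)
view = base[2:]  # same underlying buffer, offset start

# data_ptr() moves with the view's storage offset
assert view.data_ptr() != base.data_ptr()
# storage().data_ptr() identifies the shared underlying buffer
assert view.storage().data_ptr() == base.storage().data_ptr()
```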
tensordict/utils.py (10 additions, 1 deletion)
@@ -7,7 +7,6 @@
 import collections
 import dataclasses
 import inspect
-
 import math
 import os

@@ -50,6 +49,7 @@
 from torch import Tensor
 from torch._C import _disabled_torch_function_impl
 from torch.nn.parameter import _ParameterMeta
+from torch.utils.data._utils.worker import _generate_state

 if TYPE_CHECKING:
     from tensordict.memmap_deprec import MemmapTensor as _MemmapTensor

@@ -1722,3 +1722,12 @@ def _legacy_lazy(func):
     )
     func.LEGACY = True
     return func
+
+
+# Process initializer for map
+def _proc_init(base_seed, queue):
+    worker_id = queue.get(timeout=10)
+    seed = base_seed + worker_id
+    torch.manual_seed(seed)
+    np_seed = _generate_state(base_seed, worker_id)
+    np.random.seed(np_seed)
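`_proc_init` hands each pool worker a unique id through the queue and derives per-library seeds from it; `_generate_state` is the helper torch's DataLoader workers use to turn a base seed and worker id into a NumPy-compatible seed state. A self-contained sketch of the same pattern, with a hypothetical `draw` task:

```python
import multiprocessing as mp

import numpy as np
import torch
from torch.utils.data._utils.worker import _generate_state

def _init(base_seed, queue):
    # each worker pops a distinct id, so no two workers share a torch seed
    worker_id = queue.get(timeout=10)
    torch.manual_seed(base_seed + worker_id)
    # NumPy seeds must fit 32 bits; _generate_state mixes (base_seed, worker_id)
    # into a valid seed sequence instead of risking overflow
    np.random.seed(_generate_state(base_seed, worker_id))

def draw(_):
    return torch.randint(0, 1000, ()).item()

if __name__ == "__main__":
    num_workers = 4
    queue = mp.Queue(maxsize=num_workers)
    for i in range(num_workers):
        queue.put(i)
    with mp.Pool(num_workers, initializer=_init, initargs=(17, queue)) as pool:
        # each draw comes from a per-worker seeded stream; which worker
        # serves which task is still nondeterministic
        print(pool.map(draw, range(num_workers)))
```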
test/test_tensordict.py (103 additions, 0 deletions)

@@ -6546,6 +6546,109 @@ def test_modules(self, as_module):
     assert y._tensor.shape[0] == param_batch

+
+class TestMap:
+    """Tests for TensorDict.map that are independent of tensordict's type."""
+
+    @classmethod
+    def get_rand_incr(cls, td):
+        # torch
+        td["r"] = td["r"] + torch.randint(0, 100, ()).item()
+        # numpy
+        td["s"] = td["s"] + np.random.randint(0, 100, ()).item()
+        return td
+
+    def test_map_seed(self):
+        pytest.skip(
+            reason="Using max_tasks_per_child is unstable and can cause multiple processes to start over even though all jobs are completed",
+        )
+
+        if mp.get_start_method(allow_none=True) is None:
+            mp.set_start_method("spawn")
+        td = TensorDict(
+            {
+                "r": torch.zeros(20, dtype=torch.int),
+                "s": torch.zeros(20, dtype=torch.int),
+                "c": torch.arange(20),
+            },
+            batch_size=[20],
+        )
+        generator = torch.Generator()
+        # we use 4 workers with max 5 items each,
+        # making sure that no worker does more than any other.
+        generator.manual_seed(0)
+        td_out_0 = td.map(
+            TestMap.get_rand_incr,
+            num_workers=4,
+            generator=generator,
+            chunksize=1,
+            max_tasks_per_child=5,
+        )
+        print("got 1")
+        generator.manual_seed(0)
+        td_out_1 = td.map(
+            TestMap.get_rand_incr,
+            num_workers=4,
+            generator=generator,
+            chunksize=1,
+            max_tasks_per_child=5,
+        )
+        print("got 2")
+        # we cannot know which worker picks which job, but since their seed
+        # offsets span 0 to 3 and each produces one number per job, we can
+        # check that those numbers are exactly what we were expecting.
+        assert (td_out_0["r"].sort().values == td_out_1["r"].sort().values).all(), (
+            td_out_0["r"].sort().values,
+            td_out_1["r"].sort().values,
+        )
+        assert (td_out_0["s"].sort().values == td_out_1["s"].sort().values).all(), (
+            td_out_0["s"].sort().values,
+            td_out_1["s"].sort().values,
+        )
+
+    def test_map_seed_single(self):
+        # A cheap version of the previous test
+        if mp.get_start_method(allow_none=True) is None:
+            mp.set_start_method("spawn")
+        td = TensorDict(
+            {
+                "r": torch.zeros(20, dtype=torch.int),
+                "s": torch.zeros(20, dtype=torch.int),
+                "c": torch.arange(20),
+            },
+            batch_size=[20],
+        )
+        generator = torch.Generator()
+        # a single worker picks up every job, so both calls are seeded
+        # identically and must produce the same numbers.
+        generator.manual_seed(0)
+        td_out_0 = td.map(
+            TestMap.get_rand_incr,
+            num_workers=1,
+            generator=generator,
+            chunksize=1,
+        )
+        print("got 1")
+        generator.manual_seed(0)
+        td_out_1 = td.map(
+            TestMap.get_rand_incr,
+            num_workers=1,
+            generator=generator,
+            chunksize=1,
+        )
+        print("got 2")
+        # check that the numbers are exactly what we were expecting.
+        assert (td_out_0["r"].sort().values == td_out_1["r"].sort().values).all(), (
+            td_out_0["r"].sort().values,
+            td_out_1["r"].sort().values,
+        )
+        assert (td_out_0["s"].sort().values == td_out_1["s"].sort().values).all(), (
+            td_out_0["s"].sort().values,
+            td_out_1["s"].sort().values,
+        )
+

 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
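As a numeric footnote to the docstring's caution about low-valued seeds: with 8 workers, two runs whose base seeds differ by 1 share 7 of their 8 worker seeds (illustrative arithmetic only):

```python
# worker seeds are base_seed + offset for offsets 0..num_workers-1
exp_a = set(range(4, 4 + 8))  # base seed 4 -> worker seeds 4..11
exp_b = set(range(5, 5 + 8))  # base seed 5 -> worker seeds 5..12
print(sorted(exp_a & exp_b))  # 7 overlapping seeds: [5, 6, 7, 8, 9, 10, 11]
```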