Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Seed workers in TensorDict.map #562

Merged
merged 43 commits into from
Nov 24, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
6bd248d
init
vmoens Nov 21, 2023
237dddf
init
vmoens Nov 21, 2023
2216e4e
amend
vmoens Nov 21, 2023
b607c0e
edit docstring
vmoens Nov 21, 2023
aa42433
maxtasksperchild
vmoens Nov 22, 2023
152612c
fix
vmoens Nov 22, 2023
8ca93b6
TestMap.fn
vmoens Nov 22, 2023
7c63518
timeout in get
vmoens Nov 22, 2023
01b9ee0
processes=num_workers
vmoens Nov 22, 2023
9ca19c7
amend
vmoens Nov 22, 2023
80ee04f
amend
vmoens Nov 22, 2023
2a4658e
amend
vmoens Nov 22, 2023
9bf4f48
amend
vmoens Nov 22, 2023
82b668e
amend
vmoens Nov 22, 2023
47f4bf7
amend
vmoens Nov 22, 2023
c8006d2
amend
vmoens Nov 22, 2023
e42f403
amend
vmoens Nov 22, 2023
988feda
amend
vmoens Nov 22, 2023
f6fdbc6
amend
vmoens Nov 22, 2023
b48db6f
amend
vmoens Nov 22, 2023
11ff0a0
amend
vmoens Nov 22, 2023
1c70cc4
amend
vmoens Nov 22, 2023
8723bde
amend
vmoens Nov 22, 2023
28921aa
amend
vmoens Nov 22, 2023
7548cd2
amend
vmoens Nov 22, 2023
bb168e2
amend
vmoens Nov 22, 2023
975f85b
amend
vmoens Nov 22, 2023
14666e6
amend
vmoens Nov 22, 2023
2b490ff
amend
vmoens Nov 22, 2023
70cc6ee
amend
vmoens Nov 22, 2023
4f5737e
amend
vmoens Nov 22, 2023
838ed0e
amend
vmoens Nov 22, 2023
3c4d96a
amend
vmoens Nov 22, 2023
0d336e1
amend
vmoens Nov 22, 2023
b92cea6
amend
vmoens Nov 22, 2023
4435f38
amend
vmoens Nov 22, 2023
1a2e65e
amend
vmoens Nov 22, 2023
4edd592
amend
vmoens Nov 22, 2023
5223dd1
amend
vmoens Nov 23, 2023
3272e80
Merge remote-tracking branch 'origin/main' into seeding-pool
vmoens Nov 23, 2023
2fa557f
amend
vmoens Nov 23, 2023
401f055
amend
vmoens Nov 23, 2023
a70e95e
fix
vmoens Nov 24, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 34 additions & 3 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
_GENERIC_NESTED_ERR,
_is_tensorclass,
_KEY_ERROR,
_proc_init,
_shape,
_split_tensordict,
_td_fields,
Expand Down Expand Up @@ -2917,6 +2918,7 @@ def map(
chunksize: int = None,
num_chunks: int = None,
pool: mp.Pool = None,
seed: int = None,
):
"""Maps a function to splits of the tensordict across one dimension.

Expand All @@ -2941,16 +2943,38 @@ def map(
of workers. For very large tensordicts, such large chunks
may not fit in memory for the operation to be done and
more chunks may be needed to make the operation practically
doable. This argument is exclusive with num_chunks.
doable. This argument is exclusive with ``num_chunks``.
num_chunks (int, optional): the number of chunks to split the tensordict
into. If none is provided, the number of chunks will equate the number
of workers. For very large tensordicts, such large chunks
may not fit in memory for the operation to be done and
more chunks may be needed to make the operation practically
doable. This argument is exclusive with chunksize.
doable. This argument is exclusive with ``chunksize``.
pool (mp.Pool, optional): a multiprocess Pool instance to use
to execute the job. If none is provided, a pool will be created
within the ``map`` method.
seed (integer, optional): the initial seed of the pool. Each member
vmoens marked this conversation as resolved.
Show resolved Hide resolved
of the pool will be seeded with the provided seed incremented
by a unique integer from ``0`` to ``num_workers - 1``. If no seed
is provided, a random integer will be used as seed.
To work with unseeded workers, a pool should be created separately
and passed to :meth:`map` directly.
.. note::
Caution should be taken when providing a low-valued seed as
this can cause autocorrelation between experiments. For example:
if 8 workers are asked for and the seed is 4, the workers' seeds will
range from 4 to 11. If the seed is 5, the workers' seeds will range
from 5 to 12. These two experiments will have an overlap of 7
seeds, which can have unexpected effects on the results.

.. note::
vmoens marked this conversation as resolved.
Show resolved Hide resolved
The goal of seeding the workers is to have an independent seed on
each worker, and NOT to have reproducible results across calls
of the ``map`` method. In other words, two experiments may and
probably will return different results as it is impossible to
know which worker will pick which job. However, we can make sure
that each worker has a different seed and that the pseudo-random
operations on each will be uncorrelated.

Examples:
>>> import torch
Expand Down Expand Up @@ -2979,7 +3003,14 @@ def map(
if pool is None:
if num_workers is None:
num_workers = mp.cpu_count() # Get the number of CPU cores
with mp.Pool(num_workers) as pool:
if seed is None:
seed = torch.empty((), dtype=torch.int64).random_()
queue = mp.Queue(maxsize=num_workers)
for i in range(num_workers):
queue.put(i)
with mp.Pool(
num_workers, initializer=_proc_init, initargs=(seed, queue)
) as pool:
return self.map(
fn, dim=dim, chunksize=chunksize, num_chunks=num_chunks, pool=pool
)
Expand Down
11 changes: 10 additions & 1 deletion tensordict/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import collections
import dataclasses
import inspect

import math
import os

Expand Down Expand Up @@ -50,6 +49,7 @@
from torch import Tensor
from torch._C import _disabled_torch_function_impl
from torch.nn.parameter import _ParameterMeta
from torch.utils.data._utils.worker import _generate_state

if TYPE_CHECKING:
from tensordict.memmap_deprec import MemmapTensor as _MemmapTensor
Expand Down Expand Up @@ -1721,3 +1721,12 @@ def _legacy_lazy(func):
)
func.LEGACY = True
return func


# Process initializer for map
def _proc_init(base_seed, queue):
worker_id = queue.get()
seed = base_seed + worker_id
torch.manual_seed(seed)
np_seed = _generate_state(base_seed, worker_id)
np.random.seed(np_seed)
vmoens marked this conversation as resolved.
Show resolved Hide resolved
37 changes: 37 additions & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -6550,6 +6550,43 @@ def test_modules(self, as_module):
assert y._tensor.shape[0] == param_batch


class TestMap:
    """Tests for TensorDict.map that are independent from tensordict's type."""

    @classmethod
    def get_rand_incr(cls, td):
        """Increment ``td["r"]`` / ``td["s"]`` by random draws from torch and numpy.

        Used as the mapped function: each worker's draws depend only on the
        seeds set by the pool initializer, so the multiset of results is
        reproducible across calls with the same base seed.
        """
        # torch
        td["r"] = td["r"] + torch.randint(0, 10, ())
        # numpy
        td["s"] = td["s"] + np.random.randint(0, 10, ())
        return td

    def test_map_seed(self):
        td = TensorDict(
            {
                "r": torch.zeros(20, dtype=torch.int),
                "s": torch.zeros(20, dtype=torch.int),
            },
            batch_size=[20],
        )
        # we use 20 workers to make sure that each worker has one item to work with
        # Using fewer could cause non-deterministic behaviour depending on the
        # workers' speed, since we cannot tell who will pick which job.
        td_out_0 = td.map(self.get_rand_incr, num_workers=20, seed=0, chunksize=1)
        td_out_1 = td.map(self.get_rand_incr, num_workers=20, seed=0, chunksize=1)
        # we cannot know which worker picks which job, but since they will all
        # have a seed from 0 to 19 and produce 1 number each, we can check that
        # those numbers are exactly what we were expecting (order-insensitive,
        # hence the sort before comparing).
        assert (td_out_0["r"].sort().values == td_out_1["r"].sort().values).all(), (
            td_out_0["r"].sort().values,
            td_out_1["r"].sort().values,
        )
        assert (td_out_0["s"].sort().values == td_out_1["s"].sort().values).all(), (
            td_out_0["s"].sort().values,
            td_out_1["s"].sort().values,
        )


if __name__ == "__main__":
    # Allow running this test file directly; unrecognized CLI arguments are
    # forwarded to pytest (e.g. -k selectors), with output capture disabled.
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
Loading