From 4435f38a6af8d7447819dc2fec364211ab723831 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 22 Nov 2023 17:05:56 +0000 Subject: [PATCH] amend --- tensordict/base.py | 4 ++-- tensordict/utils.py | 4 ++-- test/test_tensordict.py | 17 ++++++++++++----- 3 files changed, 16 insertions(+), 9 deletions(-) diff --git a/tensordict/base.py b/tensordict/base.py index f0dd149fd..b5b4182a7 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -3003,8 +3003,8 @@ def map( at little cost. """ - # from torch import multiprocessing as mp - import multiprocessing as mp + from torch import multiprocessing as mp + # import multiprocessing as mp if pool is None: if num_workers is None: diff --git a/tensordict/utils.py b/tensordict/utils.py index 009cb4313..911341dcd 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -1726,8 +1726,8 @@ def _legacy_lazy(func): # Process initializer for map def _proc_init(base_seed, queue): print('init worker', os.getpid()) - # if queue.empty(): - # exit() + if queue.empty(): + return worker_id = queue.get(timeout=10) print('worker id', worker_id) seed = base_seed + worker_id diff --git a/test/test_tensordict.py b/test/test_tensordict.py index f7d01b1e4..5868635bc 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -8,6 +8,7 @@ import json import os import re +import time import uuid import numpy as np @@ -6549,18 +6550,24 @@ def test_modules(self, as_module): assert y.dims == (d0,) assert y._tensor.shape[0] == param_batch - +COUNTER = 0 class TestMap: """Tests for TensorDict.map that are independent from tensordict's type.""" @classmethod def get_rand_incr(cls, td): - print('worker', os.getpid()) + global COUNTER + if COUNTER == 5: + print('pausing') + time.sleep(1000) + return + COUNTER += 1 + # print('worker', os.getpid()) # torch td["r"] += torch.randint(0, 100, ()).item() # numpy td["s"] += np.random.randint(0, 100, ()).item() - print(td['c']) + # print(td['c']) return td def test_map_seed(self): @@ -6582,7 +6589,7 @@ def test_map_seed(self): num_workers=4, generator=generator, chunksize=1, - max_tasks_per_child=6, + max_tasks_per_child=5, ) print('first') generator.manual_seed(0) @@ -6591,7 +6598,7 @@ def test_map_seed(self): num_workers=4, generator=generator, chunksize=1, - max_tasks_per_child=6, + max_tasks_per_child=5, ) print('second') # we cannot know which worker picks which job, but since they will all have