Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Nov 22, 2023
1 parent b92cea6 commit 4435f38
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 9 deletions.
4 changes: 2 additions & 2 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3003,8 +3003,8 @@ def map(
at little cost.
"""
# from torch import multiprocessing as mp
import multiprocessing as mp
from torch import multiprocessing as mp
# import multiprocessing as mp

if pool is None:
if num_workers is None:
Expand Down
4 changes: 2 additions & 2 deletions tensordict/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1726,8 +1726,8 @@ def _legacy_lazy(func):
# Process initializer for map
def _proc_init(base_seed, queue):
print('init worker', os.getpid())
# if queue.empty():
# exit()
if queue.empty():
return
worker_id = queue.get(timeout=10)
print('worker id', worker_id)
seed = base_seed + worker_id
Expand Down
17 changes: 12 additions & 5 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import json
import os
import re
import time
import uuid

import numpy as np
Expand Down Expand Up @@ -6549,18 +6550,24 @@ def test_modules(self, as_module):
assert y.dims == (d0,)
assert y._tensor.shape[0] == param_batch


COUNTER = 0
class TestMap:
"""Tests for TensorDict.map that are independent from tensordict's type."""

@classmethod
def get_rand_incr(cls, td):
print('worker', os.getpid())
global COUNTER
if COUNTER == 5:
print('pausing')
time.sleep(1000)
return
COUNTER += 1
# print('worker', os.getpid())
# torch
td["r"] += torch.randint(0, 100, ()).item()
# numpy
td["s"] += np.random.randint(0, 100, ()).item()
print(td['c'])
# print(td['c'])
return td

def test_map_seed(self):
Expand All @@ -6582,7 +6589,7 @@ def test_map_seed(self):
num_workers=4,
generator=generator,
chunksize=1,
max_tasks_per_child=6,
max_tasks_per_child=5,
)
print('first')
generator.manual_seed(0)
Expand All @@ -6591,7 +6598,7 @@ def test_map_seed(self):
num_workers=4,
generator=generator,
chunksize=1,
max_tasks_per_child=6,
max_tasks_per_child=5,
)
print('second')
# we cannot know which worker picks which job, but since they will all have
Expand Down

0 comments on commit 4435f38

Please sign in to comment.