Skip to content

Commit

Permalink
[BugFix] _FileHandler for windows (#577)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Nov 24, 2023
1 parent 4dbabc6 commit 91ffde1
Showing 1 changed file with 77 additions and 45 deletions.
122 changes: 77 additions & 45 deletions tensordict/memmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,55 +546,87 @@ def __getitem__(self, item):
return out


class _FileHandler:
if sys.platform == "linux":
_dir_candidates = ["/dev/shm"]
else:
_dir_candidates = []

def __init__(self, size, fd=-1, filename=None):
# borrowed from mp.heap
self.size = size
# if filename is None:
if fd == -1:
self.fd, name = tempfile.mkstemp(
prefix="pym-%d-" % os.getpid(), dir=self._choose_dir(size)
)
# self.filename = name
os.unlink(name)
util.Finalize(self, os.close, (self.fd,))
os.ftruncate(self.fd, size)
#####################
# File handler
# borrowed from mp.heap
if sys.platform == "win32":
import _winapi

class _FileHandler:
_rand = tempfile._RandomNameSequence()

def __init__(self, size):
self.size = size
for _ in range(100):
name = "pym-%d-%s" % (os.getpid(), next(self._rand))
buf = mmap.mmap(-1, size, tagname=name)
if _winapi.GetLastError() == 0:
break
# We have reopened a preexisting mmap.
buf.close()
else:
raise FileExistsError("Cannot find name for new mmap")
self.name = name
self.buffer = buf
self._state = (self.size, self.name)

def __getstate__(self):
from multiprocessing.context import assert_spawning

assert_spawning(self)
return self._state

def __setstate__(self, state):
self.size, self.name = self._state = state
# Reopen existing mmap
self.buffer = mmap.mmap(-1, self.size, tagname=self.name)
# XXX Temporarily preventing buildbot failures while determining
# XXX the correct long-term fix. See issue 23060
# assert _winapi.GetLastError() == _winapi.ERROR_ALREADY_EXISTS

else:

class _FileHandler:
if sys.platform == "linux":
_dir_candidates = ["/dev/shm"]
else:
self.fd = fd
# else:
# self.filename = filename
self.buffer = mmap.mmap(self.fd, self.size)

def _choose_dir(self, size):
# Choose a non-storage backed directory if possible,
# to improve performance
for d in self._dir_candidates:
st = os.statvfs(d)
if st.f_bavail * st.f_frsize >= size: # enough free space?
return d
tmpdir = util.get_temp_dir()
return tmpdir


def _reduce_handler(handler):
if handler.fd == -1:
raise ValueError(
"Handler is unpicklable because " "forking was enabled when it was created"
)
return _rebuild_handler, (handler.size, reduction.DupFd(handler.fd))
_dir_candidates = []

def __init__(self, size, fd=-1):
self.size = size
self.fd = fd
if fd == -1:
self.fd, name = tempfile.mkstemp(
prefix="pym-%d-" % os.getpid(), dir=self._choose_dir(size)
)
os.unlink(name)
util.Finalize(self, os.close, (self.fd,))
os.ftruncate(self.fd, size)
self.buffer = mmap.mmap(self.fd, self.size)

def _choose_dir(self, size):
# Choose a non-storage backed directory if possible,
# to improve performance
for d in self._dir_candidates:
st = os.statvfs(d)
if st.f_bavail * st.f_frsize >= size: # enough free space?
return d
tmpdir = util.get_temp_dir()
return tmpdir

def _reduce_handler(handler):
if handler.fd == -1:
raise ValueError(
"Handler is unpicklable because "
"forking was enabled when it was created"
)
return _rebuild_handler, (handler.size, reduction.DupFd(handler.fd))

def _rebuild_handler(size, dupfd):
detached = dupfd.detach()
return _FileHandler(size, detached)

def _rebuild_handler(size, dupfd):
detached = dupfd.detach()
return _FileHandler(size, detached)

reduction.register(_FileHandler, _reduce_handler)
reduction.register(_FileHandler, _reduce_handler)


def _reduce_memmap(memmap_tensor):
Expand Down

0 comments on commit 91ffde1

Please sign in to comment.