Skip to content

Commit

Permalink
handle a ThreadHandleType (#679)
Browse files Browse the repository at this point in the history
* handle a ThreadHandleType

* handle special case of _thread._ExceptHookArgs

* use threading.ExceptHookArgs for pypy

* handle when thread doesn't have native_id
  • Loading branch information
mmckerns authored Sep 13, 2024
1 parent 15d7c6d commit 8b86f50
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 2 deletions.
32 changes: 30 additions & 2 deletions dill/_dill.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,14 @@
from pickle import GLOBAL, POP
from _thread import LockType
from _thread import RLock as RLockType
try:
from _thread import _ExceptHookArgs as ExceptHookArgsType
except ImportError:
ExceptHookArgsType = None
try:
from _thread import _ThreadHandle as ThreadHandleType
except ImportError:
ThreadHandleType = None
#from io import IOBase
from types import CodeType, FunctionType, MethodType, GeneratorType, \
TracebackType, FrameType, ModuleType, BuiltinMethodType
Expand Down Expand Up @@ -775,6 +783,14 @@ def _create_typing_tuple(argz, *args): #NOTE: workaround python/cpython#94245
return typing.Tuple[()]
return typing.Tuple[argz]

if ThreadHandleType:
def _create_thread_handle(ident, done, *args): #XXX: ignores 'blocking'
from threading import _make_thread_handle
handle = _make_thread_handle(ident)
if done:
handle._set_done()
return handle

def _create_lock(locked, *args): #XXX: ignores 'blocking'
from threading import Lock
lock = Lock()
Expand Down Expand Up @@ -1306,7 +1322,15 @@ def save_generic_alias(pickler, obj):
logger.trace(pickler, "# Ga2")
return

@register(LockType)
if ThreadHandleType:
@register(ThreadHandleType)
def save_thread_handle(pickler, obj):
logger.trace(pickler, "Th: %s", obj)
pickler.save_reduce(_create_thread_handle, (obj.ident, obj.is_done()), obj=obj)
logger.trace(pickler, "# Th")
return

@register(LockType) #XXX: copied Thread will have new Event (due to new Lock)
def save_lock(pickler, obj):
logger.trace(pickler, "Lo: %s", obj)
pickler.save_reduce(_create_lock, (obj.locked(),), obj=obj)
Expand Down Expand Up @@ -1773,7 +1797,7 @@ def save_type(pickler, obj, postproc_list=None):
logger.trace(pickler, "# T6")
return

# special cases: NoneType, NotImplementedType, EllipsisType, EnumMeta
# special caes: NoneType, NotImplementedType, EllipsisType, EnumMeta, etc
elif obj is type(None):
logger.trace(pickler, "T7: %s", obj)
#XXX: pickler.save_reduce(type, (None,), obj=obj)
Expand All @@ -1791,6 +1815,10 @@ def save_type(pickler, obj, postproc_list=None):
logger.trace(pickler, "T7: %s", obj)
pickler.write(GLOBAL + b'enum\nEnumMeta\n')
logger.trace(pickler, "# T7")
elif obj is ExceptHookArgsType: #NOTE: must be after NoneType for pypy
logger.trace(pickler, "T7: %s", obj)
pickler.write(GLOBAL + b'threading\nExceptHookArgs\n')
logger.trace(pickler, "# T7")

else:
_byref = getattr(pickler, '_byref', None)
Expand Down
46 changes: 46 additions & 0 deletions dill/tests/test_threads.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#!/usr/bin/env python
#
# Author: Mike McKerns (mmckerns @caltech and @uqfoundation)
# Copyright (c) 2024 The Uncertainty Quantification Foundation.
# License: 3-clause BSD. The full license text is available at:
# - https://github.com/uqfoundation/dill/blob/master/LICENSE

import dill
dill.settings['recurse'] = True


def test_new_thread():
import threading
t = threading.Thread()
t_ = dill.copy(t)
assert t.is_alive() == t_.is_alive()
for i in ['daemon','name','ident','native_id']:
if hasattr(t, i):
assert getattr(t, i) == getattr(t_, i)

def test_run_thread():
import threading
t = threading.Thread()
t.start()
t_ = dill.copy(t)
assert t.is_alive() == t_.is_alive()
for i in ['daemon','name','ident','native_id']:
if hasattr(t, i):
assert getattr(t, i) == getattr(t_, i)

def test_join_thread():
import threading
t = threading.Thread()
t.start()
t.join()
t_ = dill.copy(t)
assert t.is_alive() == t_.is_alive()
for i in ['daemon','name','ident','native_id']:
if hasattr(t, i):
assert getattr(t, i) == getattr(t_, i)


if __name__ == '__main__':
test_new_thread()
test_run_thread()
test_join_thread()

0 comments on commit 8b86f50

Please sign in to comment.