diff --git a/rpyc/core/consts.py b/rpyc/core/consts.py index 5604c2c6..877999bb 100644 --- a/rpyc/core/consts.py +++ b/rpyc/core/consts.py @@ -32,6 +32,7 @@ HANDLE_INSPECT = 16 HANDLE_BUFFITER = 17 HANDLE_OLDSLICING = 18 +HANDLE_CTXEXIT = 19 # optimized exceptions EXC_STOP_ITERATION = 1 diff --git a/rpyc/core/netref.py b/rpyc/core/netref.py index 9c009741..aacd0166 100644 --- a/rpyc/core/netref.py +++ b/rpyc/core/netref.py @@ -14,7 +14,7 @@ '__dir__', '__doc__', '__getattr__', '__getattribute__', '__hash__', '__init__', '__metaclass__', '__module__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__slots__', '__str__', - '__weakref__', '__dict__', '__members__', '__methods__', + '__weakref__', '__dict__', '__members__', '__methods__', '__exit__', ]) """the set of attributes that are local to the netref object""" @@ -173,7 +173,8 @@ def __repr__(self): return syncreq(self, consts.HANDLE_REPR) def __str__(self): return syncreq(self, consts.HANDLE_STR) - + def __exit__(self, exc, typ, tb): + return syncreq(self, consts.HANDLE_CTXEXIT, exc) # can't pass type nor traceback # support for pickling netrefs def __reduce_ex__(self, proto): return pickle.loads, (syncreq(self, consts.HANDLE_PICKLE, proto),) diff --git a/rpyc/core/protocol.py b/rpyc/core/protocol.py index d71a3619..5345a14d 100644 --- a/rpyc/core/protocol.py +++ b/rpyc/core/protocol.py @@ -634,6 +634,15 @@ def _handle_setattr(self, oid, name, value): return self._access_attr(oid, name, (value,), "_rpyc_setattr", "allow_setattr", setattr) def _handle_callattr(self, oid, name, args, kwargs): return self._handle_getattr(oid, name)(*args, **dict(kwargs)) + def _handle_ctxexit(self, oid, exc): + if exc: + try: + raise exc + except: + exc, typ, tb = sys.exc_info() + else: + typ = tb = None + return self._handle_getattr(oid, "__exit__")(exc, typ, tb) def _handle_pickle(self, oid, proto): if not self._config["allow_pickle"]: raise ValueError("pickling is disabled") diff --git a/tests/test_context_managers.py b/tests/test_context_managers.py index 761d4ca4..1b61b2b1 100644 --- a/tests/test_context_managers.py +++ b/tests/test_context_managers.py @@ -4,16 +4,24 @@ from contextlib import contextmanager -on_context_enter = False -on_context_exit = False class MyService(rpyc.Service): + + def exposed_reset(self): + global on_context_enter, on_context_exit, on_context_exc + on_context_enter = False + on_context_exit = False + on_context_exc = False + @contextmanager def exposed_context(self, y): - global on_context_enter, on_context_exit + global on_context_enter, on_context_exit, on_context_exc on_context_enter = True try: yield 17 + y + except: + on_context_exc = True + raise finally: on_context_exit = True @@ -21,21 +29,37 @@ def exposed_context(self, y): class TestContextManagers(unittest.TestCase): def setUp(self): self.conn = rpyc.connect_thread(remote_service=MyService) + self.conn.root.reset() def tearDown(self): self.conn.close() - + def test_context(self): with self.conn.root.context(3) as x: print( "entering test" ) self.assertTrue(on_context_enter) + self.assertFalse(on_context_exc) self.assertFalse(on_context_exit) print( "got past context enter" ) self.assertEqual(x, 20) print( "got past x=20" ) + self.assertFalse(on_context_exc) self.assertTrue(on_context_exit) print( "got past on_context_exit" ) + def test_context_exception(self): + class MyException(Exception): + pass + + with self.assertRaises(MyException): + with self.conn.root.context(3): + self.assertTrue(on_context_enter) + self.assertFalse(on_context_exc) + self.assertFalse(on_context_exit) + raise MyException() + + self.assertTrue(on_context_exc) + self.assertTrue(on_context_exit) if __name__ == "__main__": unittest.main()