Issue #6: fix various problems related to the callback wrappers, including correctly looking up __exit__ on the type in register_exit()

This commit is contained in:
Nick Coghlan 2012-01-04 00:26:17 +10:00
parent 8bca9f8cd0
commit 0cbb244909
3 changed files with 44 additions and 15 deletions

View file

@ -12,7 +12,8 @@ Release History
attributes in addition to accepting exit callbacks directly
* Issue #1: Add ContextStack.preserve() to move all registered callbacks to
a new ContextStack object
* Wrapped callbacks now use functools.wraps to aid in introspection
* Wrapped callbacks now expose __wrapped__ (for direct callbacks) or __self__
(for context manager methods) attributes to aid in introspection
* Moved version number to a VERSION.txt file (read by both docs and setup.py)
* Added NEWS.rst (and incorporated into documentation)

View file

@ -164,6 +164,13 @@ class ContextStack(object):
self._callbacks = deque()
return new_stack
def _register_cm_exit(self, cm, cm_exit):
"""Helper to correctly register callbacks to __exit__ methods"""
def _exit_wrapper(*exc_details):
return cm_exit(cm, *exc_details)
_exit_wrapper.__self__ = cm
self.register_exit(_exit_wrapper)
def register_exit(self, callback):
"""Registers a callback with the standard __exit__ method signature
@ -172,11 +179,13 @@ class ContextStack(object):
Also accepts any object with an __exit__ method (registering the
method instead of the object itself)
"""
_cb_type = type(callback)
try:
exit = callback.__exit__
exit = _cb_type.__exit__
except AttributeError:
exit = callback
self._callbacks.append(exit)
self._callbacks.append(callback)
else:
self._register_cm_exit(callback, exit)
return callback # Allow use as a decorator
def register(self, callback, *args, **kwds):
@ -184,10 +193,12 @@ class ContextStack(object):
Cannot suppress exceptions.
"""
@wraps(callback)
def _wrapper(exc_type, exc, tb):
def _exit_wrapper(exc_type, exc, tb):
callback(*args, **kwds)
self.register_exit(_wrapper)
# We changed the signature, so using @wraps is not appropriate, but
# setting __wrapped__ may still help with introspection
_exit_wrapper.__wrapped__ = callback
self.register_exit(_exit_wrapper)
def enter_context(self, cm):
"""Enters the supplied context manager
@ -199,9 +210,7 @@ class ContextStack(object):
_cm_type = type(cm)
_exit = _cm_type.__exit__
result = _cm_type.__enter__(cm)
def _exit_wrapper(*exc_details):
return _exit(cm, *exc_details)
self.register_exit(_exit_wrapper)
self._register_cm_exit(cm, _exit)
return result
def close(self):

View file

@ -330,8 +330,9 @@ class TestContextStack(unittest.TestCase):
else:
self.assertIsNone(stack.register(_exit))
for wrapper in stack._callbacks:
self.assertEqual(wrapper.__name__, _exit.__name__)
self.assertEqual(wrapper.__doc__, _exit.__doc__)
self.assertIs(wrapper.__wrapped__, _exit)
self.assertNotEqual(wrapper.__name__, _exit.__name__)
self.assertIsNone(wrapper.__doc__, _exit.__doc__)
self.assertEqual(result, expected)
def test_register_exit(self):
@ -353,11 +354,19 @@ class TestContextStack(unittest.TestCase):
self.check_exc(*exc_details)
with ContextStack() as stack:
stack.register_exit(_expect_ok)
stack.register_exit(ExitCM(_expect_ok))
self.assertIs(stack._callbacks[-1], _expect_ok)
cm = ExitCM(_expect_ok)
stack.register_exit(cm)
self.assertIs(stack._callbacks[-1].__self__, cm)
stack.register_exit(_suppress_exc)
stack.register_exit(ExitCM(_expect_exc))
self.assertIs(stack._callbacks[-1], _suppress_exc)
cm = ExitCM(_expect_exc)
stack.register_exit(cm)
self.assertIs(stack._callbacks[-1].__self__, cm)
stack.register_exit(_expect_exc)
self.assertIs(stack._callbacks[-1], _expect_exc)
stack.register_exit(_expect_exc)
self.assertIs(stack._callbacks[-1], _expect_exc)
1/0
def test_enter_context(self):
@ -368,11 +377,13 @@ class TestContextStack(unittest.TestCase):
result.append(3)
result = []
cm = TestCM()
with ContextStack() as stack:
@stack.register # Registered first => cleaned up last
def _exit():
result.append(4)
stack.enter_context(TestCM())
stack.enter_context(cm)
self.assertIs(stack._callbacks[-1].__self__, cm)
result.append(2)
self.assertEqual(result, [1, 2, 3, 4])
@ -398,6 +409,14 @@ class TestContextStack(unittest.TestCase):
new_stack.close()
self.assertEqual(result, [1, 2, 3])
def test_instance_bypass(self):
class Example(object): pass
cm = Example()
cm.__exit__ = object()
stack = ContextStack()
self.assertRaises(AttributeError, stack.enter_context, cm)
stack.register_exit(cm)
self.assertIs(stack._callbacks[-1], cm)
if __name__ == "__main__":
import unittest