From 0cbb244909322fc9a49b84af2c50d6e98b32054c Mon Sep 17 00:00:00 2001 From: Nick Coghlan <@ncoghlan> Date: Wed, 4 Jan 2012 00:26:17 +1000 Subject: [PATCH] Issue #6: fix various problems related to the callback wrappers, including correctly looking up __exit__ on the type in register_exit() --- NEWS.rst | 3 ++- contextlib2.py | 27 ++++++++++++++++++--------- test_contextlib2.py | 29 ++++++++++++++++++++++++----- 3 files changed, 44 insertions(+), 15 deletions(-) diff --git a/NEWS.rst b/NEWS.rst index fd3cde5..d8a972d 100644 --- a/NEWS.rst +++ b/NEWS.rst @@ -12,7 +12,8 @@ Release History attributes in addition to accepting exit callbacks directly * Issue #1: Add ContextStack.preserve() to move all registered callbacks to a new ContextStack object -* Wrapped callbacks now use functools.wraps to aid in introspection +* Wrapped callbacks now expose __wrapped__ (for direct callbacks) or __self__ +(for context manager methods) attributes to aid in introspection * Moved version number to a VERSION.txt file (read by both docs and setup.py) * Added NEWS.rst (and incorporated into documentation) diff --git a/contextlib2.py b/contextlib2.py index a9c24a3..6d408b1 100644 --- a/contextlib2.py +++ b/contextlib2.py @@ -164,6 +164,13 @@ class ContextStack(object): self._callbacks = deque() return new_stack + def _register_cm_exit(self, cm, cm_exit): + """Helper to correctly register callbacks to __exit__ methods""" + def _exit_wrapper(*exc_details): + return cm_exit(cm, *exc_details) + _exit_wrapper.__self__ = cm + self.register_exit(_exit_wrapper) + def register_exit(self, callback): """Registers a callback with the standard __exit__ method signature @@ -172,11 +179,13 @@ class ContextStack(object): Also accepts any object with an __exit__ method (registering the method instead of the object itself) """ + _cb_type = type(callback) try: - exit = callback.__exit__ + exit = _cb_type.__exit__ except AttributeError: - exit = callback - self._callbacks.append(exit) + self._callbacks.append(callback) + else: + self._register_cm_exit(callback, exit) return callback # Allow use as a decorator def register(self, callback, *args, **kwds): @@ -184,10 +193,12 @@ class ContextStack(object): Cannot suppress exceptions. """ - @wraps(callback) - def _wrapper(exc_type, exc, tb): + def _exit_wrapper(exc_type, exc, tb): callback(*args, **kwds) - self.register_exit(_wrapper) + # We changed the signature, so using @wraps is not appropriate, but + # setting __wrapped__ may still help with introspection + _exit_wrapper.__wrapped__ = callback + self.register_exit(_exit_wrapper) def enter_context(self, cm): """Enters the supplied context manager @@ -199,9 +210,7 @@ class ContextStack(object): _cm_type = type(cm) _exit = _cm_type.__exit__ result = _cm_type.__enter__(cm) - def _exit_wrapper(*exc_details): - return _exit(cm, *exc_details) - self.register_exit(_exit_wrapper) + self._register_cm_exit(cm, _exit) return result def close(self): diff --git a/test_contextlib2.py b/test_contextlib2.py index 8ee7bdb..f83f217 100755 --- a/test_contextlib2.py +++ b/test_contextlib2.py @@ -330,8 +330,9 @@ class TestContextStack(unittest.TestCase): else: self.assertIsNone(stack.register(_exit)) for wrapper in stack._callbacks: - self.assertEqual(wrapper.__name__, _exit.__name__) - self.assertEqual(wrapper.__doc__, _exit.__doc__) + self.assertIs(wrapper.__wrapped__, _exit) + self.assertNotEqual(wrapper.__name__, _exit.__name__) + self.assertIsNone(wrapper.__doc__, _exit.__doc__) self.assertEqual(result, expected) def test_register_exit(self): @@ -353,11 +354,19 @@ class TestContextStack(unittest.TestCase): self.check_exc(*exc_details) with ContextStack() as stack: stack.register_exit(_expect_ok) - stack.register_exit(ExitCM(_expect_ok)) + self.assertIs(stack._callbacks[-1], _expect_ok) + cm = ExitCM(_expect_ok) + stack.register_exit(cm) + self.assertIs(stack._callbacks[-1].__self__, cm) stack.register_exit(_suppress_exc) - stack.register_exit(ExitCM(_expect_exc)) + self.assertIs(stack._callbacks[-1], _suppress_exc) + cm = ExitCM(_expect_exc) + stack.register_exit(cm) + self.assertIs(stack._callbacks[-1].__self__, cm) stack.register_exit(_expect_exc) + self.assertIs(stack._callbacks[-1], _expect_exc) stack.register_exit(_expect_exc) + self.assertIs(stack._callbacks[-1], _expect_exc) 1/0 def test_enter_context(self): @@ -368,11 +377,13 @@ class TestContextStack(unittest.TestCase): result.append(3) result = [] + cm = TestCM() with ContextStack() as stack: @stack.register # Registered first => cleaned up last def _exit(): result.append(4) - stack.enter_context(TestCM()) + stack.enter_context(cm) + self.assertIs(stack._callbacks[-1].__self__, cm) result.append(2) self.assertEqual(result, [1, 2, 3, 4]) @@ -398,6 +409,14 @@ class TestContextStack(unittest.TestCase): new_stack.close() self.assertEqual(result, [1, 2, 3]) + def test_instance_bypass(self): + class Example(object): pass + cm = Example() + cm.__exit__ = object() + stack = ContextStack() + self.assertRaises(AttributeError, stack.enter_context, cm) + stack.register_exit(cm) + self.assertIs(stack._callbacks[-1], cm) if __name__ == "__main__": import unittest